mirror of
https://github.com/zu1k/nali.git
synced 2025-01-22 13:19:02 +08:00
Merge pull request #102 from linbuxiao/up
feat: make update module flex
This commit is contained in:
commit
2cba310c34
@ -2,6 +2,7 @@ package cmd
|
||||
|
||||
import (
|
||||
"github.com/zu1k/nali/internal/db"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
@ -9,14 +10,20 @@ import (
|
||||
// updateCmd represents the update command
|
||||
var updateCmd = &cobra.Command{
|
||||
Use: "update",
|
||||
Short: "update chunzhen ip database",
|
||||
Long: `update chunzhen ip database`,
|
||||
Short: "update chunzhen, zxipv6, ip2region ip database and cdn",
|
||||
Long: `update chunzhen, zxipv6, ip2region ip database and cdn. Use commas to separate`,
|
||||
Example: "nali update --db chunzhen,cdn",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
db.UpdateAllDB()
|
||||
|
||||
DBs, _ := cmd.Flags().GetString("db")
|
||||
var DBNameArray []string
|
||||
if DBs != "" {
|
||||
DBNameArray = strings.Split(DBs, ",")
|
||||
}
|
||||
db.UpdateDB(DBNameArray...)
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(updateCmd)
|
||||
rootCmd.PersistentFlags().String("db", "", "choose db you want to update")
|
||||
}
|
||||
|
@ -104,28 +104,63 @@ func GetIPDBbyName(name string) (dbif.DB, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateAllDB() {
|
||||
func getDBInfoMap() map[string]func() error {
|
||||
return map[string]func() error{
|
||||
"chunzhen": func() error {
|
||||
log.Println("正在下载最新 纯真 IPv4数据库...")
|
||||
_, err := qqwry.Download(QQWryPath)
|
||||
if err != nil {
|
||||
log.Fatalln("数据库 QQWry 下载失败:", err)
|
||||
}
|
||||
|
||||
return err
|
||||
},
|
||||
"zxipv6": func() error {
|
||||
log.Println("正在下载最新 ZX IPv6数据库...")
|
||||
_, err = zxipv6wry.Download(ZXIPv6WryPath)
|
||||
_, err := zxipv6wry.Download(ZXIPv6WryPath)
|
||||
if err != nil {
|
||||
log.Fatalln("数据库 ZXIPv6Wry 下载失败:", err)
|
||||
}
|
||||
_, err = ip2region.Download(Ip2RegionPath)
|
||||
return err
|
||||
},
|
||||
"ip2region": func() error {
|
||||
log.Println("正在下载最新 Ip2Region 数据库...")
|
||||
_, err := ip2region.Download(Ip2RegionPath)
|
||||
if err != nil {
|
||||
log.Fatalln("数据库 Ip2Region 下载失败:", err)
|
||||
}
|
||||
|
||||
return err
|
||||
},
|
||||
"cdn": func() error {
|
||||
log.Println("正在下载最新 CDN服务提供商数据库...")
|
||||
_, err = cdn.Download(CDNPath)
|
||||
_, err := cdn.Download(CDNPath)
|
||||
if err != nil {
|
||||
log.Fatalln("数据库 CDN 下载失败:", err)
|
||||
}
|
||||
return err
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateDB(dbName ...string) {
|
||||
dbInfo := getDBInfoMap()
|
||||
isAll := false
|
||||
if len(dbName) == 0 {
|
||||
isAll = true
|
||||
}
|
||||
keySet := make(map[string]struct{})
|
||||
for _, v := range dbName {
|
||||
keySet[v] = struct{}{}
|
||||
}
|
||||
for key, action := range dbInfo {
|
||||
_, ok := keySet[key]
|
||||
if !isAll && !ok {
|
||||
continue
|
||||
}
|
||||
if err := action(); err != nil {
|
||||
// keep loop
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Find(typ dbif.QueryType, query string) string {
|
||||
|
Loading…
Reference in New Issue
Block a user