1
0
mirror of https://github.com/zu1k/nali.git synced 2025-01-22 13:19:02 +08:00

Merge pull request #102 from linbuxiao/up

feat: make update module flex
This commit is contained in:
zu1k 2022-05-05 17:03:45 +08:00 committed by GitHub
commit 2cba310c34
2 changed files with 64 additions and 22 deletions

View File

@ -2,21 +2,28 @@ package cmd
import ( import (
"github.com/zu1k/nali/internal/db" "github.com/zu1k/nali/internal/db"
"strings"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
// updateCmd represents the update command // updateCmd represents the update command
var updateCmd = &cobra.Command{ var updateCmd = &cobra.Command{
Use: "update", Use: "update",
Short: "update chunzhen ip database", Short: "update chunzhen, zxipv6, ip2region ip database and cdn",
Long: `update chunzhen ip database`, Long: `update chunzhen, zxipv6, ip2region ip database and cdn. Use commas to separate`,
Example: "nali update --db chunzhen,cdn",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
db.UpdateAllDB() DBs, _ := cmd.Flags().GetString("db")
var DBNameArray []string
if DBs != "" {
DBNameArray = strings.Split(DBs, ",")
}
db.UpdateDB(DBNameArray...)
}, },
} }
func init() { func init() {
rootCmd.AddCommand(updateCmd) rootCmd.AddCommand(updateCmd)
rootCmd.PersistentFlags().String("db", "", "choose db you want to update")
} }

View File

@ -104,27 +104,62 @@ func GetIPDBbyName(name string) (dbif.DB, error) {
} }
} }
func UpdateAllDB() { func getDBInfoMap() map[string]func() error {
log.Println("正在下载最新 纯真 IPv4数据库...") return map[string]func() error{
_, err := qqwry.Download(QQWryPath) "chunzhen": func() error {
if err != nil { log.Println("正在下载最新 纯真 IPv4数据库...")
log.Fatalln("数据库 QQWry 下载失败:", err) _, err := qqwry.Download(QQWryPath)
if err != nil {
log.Fatalln("数据库 QQWry 下载失败:", err)
}
return err
},
"zxipv6": func() error {
log.Println("正在下载最新 ZX IPv6数据库...")
_, err := zxipv6wry.Download(ZXIPv6WryPath)
if err != nil {
log.Fatalln("数据库 ZXIPv6Wry 下载失败:", err)
}
return err
},
"ip2region": func() error {
log.Println("正在下载最新 Ip2Region 数据库...")
_, err := ip2region.Download(Ip2RegionPath)
if err != nil {
log.Fatalln("数据库 Ip2Region 下载失败:", err)
}
return err
},
"cdn": func() error {
log.Println("正在下载最新 CDN服务提供商数据库...")
_, err := cdn.Download(CDNPath)
if err != nil {
log.Fatalln("数据库 CDN 下载失败:", err)
}
return err
},
} }
}
log.Println("正在下载最新 ZX IPv6数据库...") func UpdateDB(dbName ...string) {
_, err = zxipv6wry.Download(ZXIPv6WryPath) dbInfo := getDBInfoMap()
if err != nil { isAll := false
log.Fatalln("数据库 ZXIPv6Wry 下载失败:", err) if len(dbName) == 0 {
isAll = true
} }
_, err = ip2region.Download(Ip2RegionPath) keySet := make(map[string]struct{})
if err != nil { for _, v := range dbName {
log.Fatalln("数据库 Ip2Region 下载失败:", err) keySet[v] = struct{}{}
} }
for key, action := range dbInfo {
log.Println("正在下载最新 CDN服务提供商数据库...") _, ok := keySet[key]
_, err = cdn.Download(CDNPath) if !isAll && !ok {
if err != nil { continue
log.Fatalln("数据库 CDN 下载失败:", err) }
if err := action(); err != nil {
// keep loop
continue
}
} }
} }