diff --git a/cmd/update.go b/cmd/update.go index ee27e07..4cedba5 100644 --- a/cmd/update.go +++ b/cmd/update.go @@ -2,21 +2,28 @@ package cmd import ( "github.com/zu1k/nali/internal/db" + "strings" "github.com/spf13/cobra" ) // updateCmd represents the update command var updateCmd = &cobra.Command{ - Use: "update", - Short: "update chunzhen ip database", - Long: `update chunzhen ip database`, + Use: "update", + Short: "update chunzhen, zxipv6, ip2region ip database and cdn", + Long: `update chunzhen, zxipv6, ip2region ip database and cdn. Use commas to separate`, + Example: "nali update --db chunzhen,cdn", Run: func(cmd *cobra.Command, args []string) { - db.UpdateAllDB() - + DBs, _ := cmd.Flags().GetString("db") + var DBNameArray []string + if DBs != "" { + DBNameArray = strings.Split(DBs, ",") + } + db.UpdateDB(DBNameArray...) }, } func init() { rootCmd.AddCommand(updateCmd) + rootCmd.PersistentFlags().String("db", "", "choose db you want to update") } diff --git a/internal/db/db.go b/internal/db/db.go index 617e5ce..e890318 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -104,27 +104,62 @@ func GetIPDBbyName(name string) (dbif.DB, error) { } } -func UpdateAllDB() { - log.Println("正在下载最新 纯真 IPv4数据库...") - _, err := qqwry.Download(QQWryPath) - if err != nil { - log.Fatalln("数据库 QQWry 下载失败:", err) +func getDBInfoMap() map[string]func() error { + return map[string]func() error{ + "chunzhen": func() error { + log.Println("正在下载最新 纯真 IPv4数据库...") + _, err := qqwry.Download(QQWryPath) + if err != nil { + log.Fatalln("数据库 QQWry 下载失败:", err) + } + return err + }, + "zxipv6": func() error { + log.Println("正在下载最新 ZX IPv6数据库...") + _, err := zxipv6wry.Download(ZXIPv6WryPath) + if err != nil { + log.Fatalln("数据库 ZXIPv6Wry 下载失败:", err) + } + return err + }, + "ip2region": func() error { + log.Println("正在下载最新 Ip2Region 数据库...") + _, err := ip2region.Download(Ip2RegionPath) + if err != nil { + log.Fatalln("数据库 Ip2Region 下载失败:", err) + } + return err + }, + "cdn": func() error { + log.Println("正在下载最新 CDN服务提供商数据库...") + _, err := cdn.Download(CDNPath) + if err != nil { + log.Fatalln("数据库 CDN 下载失败:", err) + } + return err + }, } +} - log.Println("正在下载最新 ZX IPv6数据库...") - _, err = zxipv6wry.Download(ZXIPv6WryPath) - if err != nil { - log.Fatalln("数据库 ZXIPv6Wry 下载失败:", err) +func UpdateDB(dbName ...string) { + dbInfo := getDBInfoMap() + isAll := false + if len(dbName) == 0 { + isAll = true } - _, err = ip2region.Download(Ip2RegionPath) - if err != nil { - log.Fatalln("数据库 Ip2Region 下载失败:", err) + keySet := make(map[string]struct{}) + for _, v := range dbName { + keySet[v] = struct{}{} } - - log.Println("正在下载最新 CDN服务提供商数据库...") - _, err = cdn.Download(CDNPath) - if err != nil { - log.Fatalln("数据库 CDN 下载失败:", err) + for key, action := range dbInfo { + _, ok := keySet[key] + if !isAll && !ok { + continue + } + if err := action(); err != nil { + // keep loop + continue + } } }