1
0
mirror of https://github.com/zu1k/nali.git synced 2025-01-22 13:19:02 +08:00

Merge pull request #102 from linbuxiao/up

feat: make update module flex
This commit is contained in:
zu1k 2022-05-05 17:03:45 +08:00 committed by GitHub
commit 2cba310c34
2 changed files with 64 additions and 22 deletions

View File

@ -2,6 +2,7 @@ package cmd
import (
"github.com/zu1k/nali/internal/db"
"strings"
"github.com/spf13/cobra"
)
@ -9,14 +10,20 @@ import (
// updateCmd represents the update command
var updateCmd = &cobra.Command{
Use: "update",
Short: "update chunzhen ip database",
Long: `update chunzhen ip database`,
Short: "update chunzhen, zxipv6, ip2region ip database and cdn",
Long: `update chunzhen, zxipv6, ip2region ip database and cdn. Use commas to separate`,
Example: "nali update --db chunzhen,cdn",
Run: func(cmd *cobra.Command, args []string) {
db.UpdateAllDB()
DBs, _ := cmd.Flags().GetString("db")
var DBNameArray []string
if DBs != "" {
DBNameArray = strings.Split(DBs, ",")
}
db.UpdateDB(DBNameArray...)
},
}
func init() {
rootCmd.AddCommand(updateCmd)
rootCmd.PersistentFlags().String("db", "", "choose db you want to update")
}

View File

@ -104,28 +104,63 @@ func GetIPDBbyName(name string) (dbif.DB, error) {
}
}
func UpdateAllDB() {
func getDBInfoMap() map[string]func() error {
return map[string]func() error{
"chunzhen": func() error {
log.Println("正在下载最新 纯真 IPv4数据库...")
_, err := qqwry.Download(QQWryPath)
if err != nil {
log.Fatalln("数据库 QQWry 下载失败:", err)
}
return err
},
"zxipv6": func() error {
log.Println("正在下载最新 ZX IPv6数据库...")
_, err = zxipv6wry.Download(ZXIPv6WryPath)
_, err := zxipv6wry.Download(ZXIPv6WryPath)
if err != nil {
log.Fatalln("数据库 ZXIPv6Wry 下载失败:", err)
}
_, err = ip2region.Download(Ip2RegionPath)
return err
},
"ip2region": func() error {
log.Println("正在下载最新 Ip2Region 数据库...")
_, err := ip2region.Download(Ip2RegionPath)
if err != nil {
log.Fatalln("数据库 Ip2Region 下载失败:", err)
}
return err
},
"cdn": func() error {
log.Println("正在下载最新 CDN服务提供商数据库...")
_, err = cdn.Download(CDNPath)
_, err := cdn.Download(CDNPath)
if err != nil {
log.Fatalln("数据库 CDN 下载失败:", err)
}
return err
},
}
}
func UpdateDB(dbName ...string) {
dbInfo := getDBInfoMap()
isAll := false
if len(dbName) == 0 {
isAll = true
}
keySet := make(map[string]struct{})
for _, v := range dbName {
keySet[v] = struct{}{}
}
for key, action := range dbInfo {
_, ok := keySet[key]
if !isAll && !ok {
continue
}
if err := action(); err != nil {
// keep loop
continue
}
}
}
func Find(typ dbif.QueryType, query string) string {