mirror of
https://github.com/zu1k/nali.git
synced 2025-01-22 13:19:02 +08:00
Merge pull request #102 from linbuxiao/up
feat: make update module flex
This commit is contained in:
commit
2cba310c34
@ -2,21 +2,28 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/zu1k/nali/internal/db"
|
"github.com/zu1k/nali/internal/db"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
// updateCmd represents the update command
|
// updateCmd represents the update command
|
||||||
var updateCmd = &cobra.Command{
|
var updateCmd = &cobra.Command{
|
||||||
Use: "update",
|
Use: "update",
|
||||||
Short: "update chunzhen ip database",
|
Short: "update chunzhen, zxipv6, ip2region ip database and cdn",
|
||||||
Long: `update chunzhen ip database`,
|
Long: `update chunzhen, zxipv6, ip2region ip database and cdn. Use commas to separate`,
|
||||||
|
Example: "nali update --db chunzhen,cdn",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
db.UpdateAllDB()
|
DBs, _ := cmd.Flags().GetString("db")
|
||||||
|
var DBNameArray []string
|
||||||
|
if DBs != "" {
|
||||||
|
DBNameArray = strings.Split(DBs, ",")
|
||||||
|
}
|
||||||
|
db.UpdateDB(DBNameArray...)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
rootCmd.AddCommand(updateCmd)
|
rootCmd.AddCommand(updateCmd)
|
||||||
|
rootCmd.PersistentFlags().String("db", "", "choose db you want to update")
|
||||||
}
|
}
|
||||||
|
@ -104,27 +104,62 @@ func GetIPDBbyName(name string) (dbif.DB, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateAllDB() {
|
func getDBInfoMap() map[string]func() error {
|
||||||
log.Println("正在下载最新 纯真 IPv4数据库...")
|
return map[string]func() error{
|
||||||
_, err := qqwry.Download(QQWryPath)
|
"chunzhen": func() error {
|
||||||
if err != nil {
|
log.Println("正在下载最新 纯真 IPv4数据库...")
|
||||||
log.Fatalln("数据库 QQWry 下载失败:", err)
|
_, err := qqwry.Download(QQWryPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("数据库 QQWry 下载失败:", err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
"zxipv6": func() error {
|
||||||
|
log.Println("正在下载最新 ZX IPv6数据库...")
|
||||||
|
_, err := zxipv6wry.Download(ZXIPv6WryPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("数据库 ZXIPv6Wry 下载失败:", err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
"ip2region": func() error {
|
||||||
|
log.Println("正在下载最新 Ip2Region 数据库...")
|
||||||
|
_, err := ip2region.Download(Ip2RegionPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("数据库 Ip2Region 下载失败:", err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
"cdn": func() error {
|
||||||
|
log.Println("正在下载最新 CDN服务提供商数据库...")
|
||||||
|
_, err := cdn.Download(CDNPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("数据库 CDN 下载失败:", err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
log.Println("正在下载最新 ZX IPv6数据库...")
|
func UpdateDB(dbName ...string) {
|
||||||
_, err = zxipv6wry.Download(ZXIPv6WryPath)
|
dbInfo := getDBInfoMap()
|
||||||
if err != nil {
|
isAll := false
|
||||||
log.Fatalln("数据库 ZXIPv6Wry 下载失败:", err)
|
if len(dbName) == 0 {
|
||||||
|
isAll = true
|
||||||
}
|
}
|
||||||
_, err = ip2region.Download(Ip2RegionPath)
|
keySet := make(map[string]struct{})
|
||||||
if err != nil {
|
for _, v := range dbName {
|
||||||
log.Fatalln("数据库 Ip2Region 下载失败:", err)
|
keySet[v] = struct{}{}
|
||||||
}
|
}
|
||||||
|
for key, action := range dbInfo {
|
||||||
log.Println("正在下载最新 CDN服务提供商数据库...")
|
_, ok := keySet[key]
|
||||||
_, err = cdn.Download(CDNPath)
|
if !isAll && !ok {
|
||||||
if err != nil {
|
continue
|
||||||
log.Fatalln("数据库 CDN 下载失败:", err)
|
}
|
||||||
|
if err := action(); err != nil {
|
||||||
|
// keep loop
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user