diff --git a/internal/db/update.go b/internal/db/update.go index d7320c2..9154dd5 100644 --- a/internal/db/update.go +++ b/internal/db/update.go @@ -1,6 +1,8 @@ package db import ( + "errors" + "github.com/zu1k/nali/pkg/qqwry" "log" "strings" "time" @@ -33,6 +35,11 @@ var DbNameListForUpdate = []string{ "cdn", } +var DbCheckFunc = map[Format]func([]byte) bool{ + FormatQQWry: qqwry.CheckFile, + FormatZXIPv6Wry: zxipv6wry.CheckFile, +} + func getUpdateFuncByName(name string) (func() error, string) { name = strings.TrimSpace(name) if db := getDbByName(name); db != nil { @@ -40,12 +47,20 @@ func getUpdateFuncByName(name string) (func() error, string) { if len(db.DownloadUrls) > 0 { return func() error { log.Printf("正在下载最新 %s 数据库...\n", db.Name) - _, err := download.Download(db.File, db.DownloadUrls...) + data, err := download.Download(db.File, db.DownloadUrls...) if err != nil { - log.Printf("%s 数据库下载失败: %s\n", db.Name, db.File) + log.Printf("%s 数据库下载失败,请手动下载解压后保存到本地: %s \n", db.Name, db.File) + log.Println("下载链接:", db.DownloadUrls) log.Println("error:", err) return err } else { + if check, ok := DbCheckFunc[db.Format]; ok { + if !check(data) { + log.Printf("%s 数据库下载失败,请手动下载解压后保存到本地: %s \n", db.Name, db.File) + log.Println("下载链接:", db.DownloadUrls) + return errors.New("数据库内容出错") + } + } log.Printf("%s 数据库下载成功: %s\n", db.Name, db.File) return nil } diff --git a/pkg/download/download.go b/pkg/download/download.go index 5719c45..fc1e37e 100644 --- a/pkg/download/download.go +++ b/pkg/download/download.go @@ -1,23 +1,20 @@ package download import ( - "log" - + "errors" "github.com/zu1k/nali/pkg/common" ) func Download(filePath string, urls ...string) (data []byte, err error) { - _ = urls[0] + if len(urls) == 0 { + return nil, errors.New("未指定下载 url") + } data, err = common.GetHttpClient().Get(urls...) if err != nil { - log.Printf("文件下载失败,请手动下载解压后保存到本地: %s \n", filePath) - log.Println("下载链接:", urls) return } - if err := common.SaveFile(filePath, data); err == nil { - log.Println("文件下载成功:", filePath) - } + err = common.SaveFile(filePath, data) return } diff --git a/pkg/qqwry/qqwry.go b/pkg/qqwry/qqwry.go index 9b59001..d5e3dc1 100644 --- a/pkg/qqwry/qqwry.go +++ b/pkg/qqwry/qqwry.go @@ -45,7 +45,7 @@ func NewQQwry(filePath string) (*QQwry, error) { } } - if len(fileData) < 8 { + if !CheckFile(fileData) { log.Fatalln("纯真 IP 库存在错误,请重新下载") } @@ -53,10 +53,6 @@ func NewQQwry(filePath string) (*QQwry, error) { start := binary.LittleEndian.Uint32(header[:4]) end := binary.LittleEndian.Uint32(header[4:]) - if uint32(len(fileData)) < end+7 { - log.Fatalln("纯真 IP 库存在错误,请重新下载") - } - return &QQwry{ IPDB: wry.IPDB[uint32]{ Data: fileData, @@ -90,3 +86,19 @@ func (db QQwry) Find(query string, params ...string) (result fmt.Stringer, err e reader.Parse(offset + 4) return reader.Result.DecodeGBK(), nil } + +func CheckFile(data []byte) bool { + if len(data) < 8 { + return false + } + + header := data[0:8] + start := binary.LittleEndian.Uint32(header[:4]) + end := binary.LittleEndian.Uint32(header[4:]) + + if start >= end || uint32(len(data)) < end+7 { + return false + } + + return true +} diff --git a/pkg/zxipv6wry/update.go b/pkg/zxipv6wry/update.go index f25ca57..1df1d80 100644 --- a/pkg/zxipv6wry/update.go +++ b/pkg/zxipv6wry/update.go @@ -1,6 +1,7 @@ package zxipv6wry import ( + "errors" "io" "io/ioutil" "log" @@ -18,6 +19,12 @@ func Download(filePath ...string) (data []byte, err error) { return } + if !CheckFile(data) { + log.Printf("ZX IPv6数据库下载出错,请手动下载解压后保存到本地: %s \n", filePath) + log.Println("下载链接: https://ip.zxinc.org/ip.7z") + return nil, errors.New("数据库下载内容出错") + } + if len(filePath) == 1 { if err := common.SaveFile(filePath[0], data); err == nil { log.Println("已将最新的 ZX IPv6数据库 保存到本地:", filePath) diff --git a/pkg/zxipv6wry/zxipv6wry.go b/pkg/zxipv6wry/zxipv6wry.go index d9650f7..9816c6e 100644 --- a/pkg/zxipv6wry/zxipv6wry.go +++ b/pkg/zxipv6wry/zxipv6wry.go @@ -39,7 +39,7 @@ func NewZXwry(filePath string) (*ZXwry, error) { } } - if len(fileData) < 24 { + if !CheckFile(fileData) { log.Fatalln("ZX IPv6数据库存在错误,请重新下载") } @@ -51,10 +51,6 @@ func NewZXwry(filePath string) (*ZXwry, error) { counts := binary.LittleEndian.Uint64(header[8:16]) end := start + counts*11 - if uint64(len(fileData)) < end { - log.Fatalln("ZX IPv6数据库存在错误,请重新下载") - } - return &ZXwry{ IPDB: wry.IPDB[uint64]{ Data: fileData, @@ -85,3 +81,25 @@ func (db *ZXwry) Find(query string, _ ...string) (result fmt.Stringer, err error reader.Parse(offset) return reader.Result, nil } + +func CheckFile(data []byte) bool { + if len(data) < 4 { + return false + } + if string(data[:4]) != "IPDB" { + return false + } + + if len(data) < 24 { + return false + } + header := data[:24] + start := binary.LittleEndian.Uint64(header[16:24]) + counts := binary.LittleEndian.Uint64(header[8:16]) + end := start + counts*11 + if start >= end || uint64(len(data)) < end { + return false + } + + return true +}