1
0
mirror of https://github.com/zu1k/nali.git synced 2025-01-22 13:19:02 +08:00

feat: check db data after download

This commit is contained in:
zu1k 2022-10-24 09:51:19 +08:00
parent fc9df8dcf0
commit b0bd2771d6
No known key found for this signature in database
GPG Key ID: AE381A8FB1EF2CC8
5 changed files with 69 additions and 20 deletions

View File

@ -1,6 +1,8 @@
package db
import (
"errors"
"github.com/zu1k/nali/pkg/qqwry"
"log"
"strings"
"time"
@ -33,6 +35,11 @@ var DbNameListForUpdate = []string{
"cdn",
}
var DbCheckFunc = map[Format]func([]byte) bool{
FormatQQWry: qqwry.CheckFile,
FormatZXIPv6Wry: zxipv6wry.CheckFile,
}
func getUpdateFuncByName(name string) (func() error, string) {
name = strings.TrimSpace(name)
if db := getDbByName(name); db != nil {
@ -40,12 +47,20 @@ func getUpdateFuncByName(name string) (func() error, string) {
if len(db.DownloadUrls) > 0 {
return func() error {
log.Printf("正在下载最新 %s 数据库...\n", db.Name)
_, err := download.Download(db.File, db.DownloadUrls...)
data, err := download.Download(db.File, db.DownloadUrls...)
if err != nil {
log.Printf("%s 数据库下载失败: %s\n", db.Name, db.File)
log.Printf("%s 数据库下载失败,请手动下载解压后保存到本地: %s \n", db.Name, db.File)
log.Println("下载链接:", db.DownloadUrls)
log.Println("error:", err)
return err
} else {
if check, ok := DbCheckFunc[db.Format]; ok {
if !check(data) {
log.Printf("%s 数据库下载失败,请手动下载解压后保存到本地: %s \n", db.Name, db.File)
log.Println("下载链接:", db.DownloadUrls)
return errors.New("数据库内容出错")
}
}
log.Printf("%s 数据库下载成功: %s\n", db.Name, db.File)
return nil
}

View File

@ -1,23 +1,20 @@
package download
import (
"log"
"errors"
"github.com/zu1k/nali/pkg/common"
)
func Download(filePath string, urls ...string) (data []byte, err error) {
_ = urls[0]
if len(urls) == 0 {
return nil, errors.New("未指定下载 url")
}
data, err = common.GetHttpClient().Get(urls...)
if err != nil {
log.Printf("文件下载失败,请手动下载解压后保存到本地: %s \n", filePath)
log.Println("下载链接:", urls)
return
}
if err := common.SaveFile(filePath, data); err == nil {
log.Println("文件下载成功:", filePath)
}
err = common.SaveFile(filePath, data)
return
}

View File

@ -45,7 +45,7 @@ func NewQQwry(filePath string) (*QQwry, error) {
}
}
if len(fileData) < 8 {
if !CheckFile(fileData) {
log.Fatalln("纯真 IP 库存在错误,请重新下载")
}
@ -53,10 +53,6 @@ func NewQQwry(filePath string) (*QQwry, error) {
start := binary.LittleEndian.Uint32(header[:4])
end := binary.LittleEndian.Uint32(header[4:])
if uint32(len(fileData)) < end+7 {
log.Fatalln("纯真 IP 库存在错误,请重新下载")
}
return &QQwry{
IPDB: wry.IPDB[uint32]{
Data: fileData,
@ -90,3 +86,19 @@ func (db QQwry) Find(query string, params ...string) (result fmt.Stringer, err e
reader.Parse(offset + 4)
return reader.Result.DecodeGBK(), nil
}
func CheckFile(data []byte) bool {
if len(data) < 8 {
return false
}
header := data[0:8]
start := binary.LittleEndian.Uint32(header[:4])
end := binary.LittleEndian.Uint32(header[4:])
if start >= end || uint32(len(data)) < end+7 {
return false
}
return true
}

View File

@ -1,6 +1,7 @@
package zxipv6wry
import (
"errors"
"io"
"io/ioutil"
"log"
@ -18,6 +19,12 @@ func Download(filePath ...string) (data []byte, err error) {
return
}
if !CheckFile(data) {
log.Printf("ZX IPv6数据库下载出错请手动下载解压后保存到本地: %s \n", filePath)
log.Println("下载链接: https://ip.zxinc.org/ip.7z")
return nil, errors.New("数据库下载内容出错")
}
if len(filePath) == 1 {
if err := common.SaveFile(filePath[0], data); err == nil {
log.Println("已将最新的 ZX IPv6数据库 保存到本地:", filePath)

View File

@ -39,7 +39,7 @@ func NewZXwry(filePath string) (*ZXwry, error) {
}
}
if len(fileData) < 24 {
if !CheckFile(fileData) {
log.Fatalln("ZX IPv6数据库存在错误请重新下载")
}
@ -51,10 +51,6 @@ func NewZXwry(filePath string) (*ZXwry, error) {
counts := binary.LittleEndian.Uint64(header[8:16])
end := start + counts*11
if uint64(len(fileData)) < end {
log.Fatalln("ZX IPv6数据库存在错误请重新下载")
}
return &ZXwry{
IPDB: wry.IPDB[uint64]{
Data: fileData,
@ -85,3 +81,25 @@ func (db *ZXwry) Find(query string, _ ...string) (result fmt.Stringer, err error
reader.Parse(offset)
return reader.Result, nil
}
func CheckFile(data []byte) bool {
if len(data) < 4 {
return false
}
if string(data[:4]) != "IPDB" {
return false
}
if len(data) < 24 {
return false
}
header := data[:24]
start := binary.LittleEndian.Uint64(header[16:24])
counts := binary.LittleEndian.Uint64(header[8:16])
end := start + counts*11
if start >= end || uint64(len(data)) < end {
return false
}
return true
}