diff --git a/go.mod b/go.mod index 30d9a4d..02bff9e 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,8 @@ go 1.18 require ( github.com/fatih/color v1.13.0 github.com/ipipdotnet/ipdb-go v1.3.1 - github.com/oschwald/geoip2-golang v1.5.0 + github.com/lionsoul2014/ip2region v2.2.0-release+incompatible + github.com/oschwald/geoip2-golang v1.6.1 github.com/saracen/go7z v0.0.0-20191010121135-9c09b6bd7fda github.com/spf13/cobra v1.3.0 golang.org/x/text v0.3.7 @@ -13,7 +14,6 @@ require ( require ( github.com/inconshreveable/mousetrap v1.0.0 // indirect - github.com/lionsoul2014/ip2region v2.2.0-release+incompatible // indirect github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect github.com/oschwald/maxminddb-golang v1.8.0 // indirect @@ -21,6 +21,6 @@ require ( github.com/saracen/solidblock v0.0.0-20190426153529-45df20abab6f // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/ulikunitz/xz v0.5.10 // indirect - golang.org/x/sys v0.0.0-20211205182925-97ca703d548d // indirect + golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect golang.org/x/tools v0.1.7 // indirect ) diff --git a/go.sum b/go.sum index 90096f9..35f5988 100644 --- a/go.sum +++ b/go.sum @@ -277,8 +277,8 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= -github.com/oschwald/geoip2-golang v1.5.0 h1:igg2yQIrrcRccB1ytFXqBfOHCjXWIoMv85lVJ1ONZzw= -github.com/oschwald/geoip2-golang v1.5.0/go.mod h1:xdvYt5xQzB8ORWFqPnqMwZpCpgNagttWdoZLlJQzg7s= +github.com/oschwald/geoip2-golang v1.6.1 h1:GKxT3yaWWNXSb7vj6D7eoJBns+lGYgx08QO0UcNm0YY= +github.com/oschwald/geoip2-golang v1.6.1/go.mod h1:xdvYt5xQzB8ORWFqPnqMwZpCpgNagttWdoZLlJQzg7s= github.com/oschwald/maxminddb-golang v1.8.0 h1:Uh/DSnGoxsyp/KYbY1AuP0tYEwfs0sCph9p/UMXK/Hk= github.com/oschwald/maxminddb-golang v1.8.0/go.mod h1:RXZtst0N6+FY/3qCNmZMBApR19cdQj43/NM9VkrNAis= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= @@ -540,8 +540,9 @@ golang.org/x/sys v0.0.0-20210908233432-aa78b53d3365/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211124211545-fe61309f8881/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211205182925-97ca703d548d h1:FjkYO/PPp4Wi0EAUOVLxePm7qVW4r4ctbWpURyuOD0E= golang.org/x/sys v0.0.0-20211205182925-97ca703d548d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 h1:nhht2DYV/Sn3qOayu8lM+cU1ii9sTLUeBQwQQfUHtrs= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/internal/db/db.go b/internal/db/db.go index 3962208..2ac381d 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -1,16 +1,16 @@ package db import ( + "log" "os" "path/filepath" "strings" - "github.com/zu1k/nali/pkg/ip2region" - "github.com/zu1k/nali/internal/constant" "github.com/zu1k/nali/pkg/cdn" "github.com/zu1k/nali/pkg/dbif" "github.com/zu1k/nali/pkg/geoip" + "github.com/zu1k/nali/pkg/ip2region" "github.com/zu1k/nali/pkg/ipip" "github.com/zu1k/nali/pkg/qqwry" "github.com/zu1k/nali/pkg/zxipv6wry" @@ -51,38 +51,44 @@ func GetDB(typ dbif.QueryType) (db dbif.DB) { return db } + var err error + switch typ { case dbif.TypeIPv4: if IPv4DBSelected != "" { - db = GetIPDBbyName(IPv4DBSelected) + db, err = GetIPDBbyName(IPv4DBSelected) } else { if Language == "zh-CN" { - db = qqwry.NewQQwry(QQWryPath) + db, err = qqwry.NewQQwry(QQWryPath) } else { - db = geoip.NewGeoIP(GeoLite2CityPath) + db, err = geoip.NewGeoIP(GeoLite2CityPath) } } case dbif.TypeIPv6: if IPv6DBSelected != "" { - db = GetIPDBbyName(IPv6DBSelected) + db, err = GetIPDBbyName(IPv6DBSelected) } else { if Language == "zh-CN" { - db = zxipv6wry.NewZXwry(ZXIPv6WryPath) + db, err = zxipv6wry.NewZXwry(ZXIPv6WryPath) } else { - db = geoip.NewGeoIP(GeoLite2CityPath) + db, err = geoip.NewGeoIP(GeoLite2CityPath) } } case dbif.TypeDomain: - db = cdn.NewCDN(CDNPath) + db, err = cdn.NewCDN(CDNPath) default: panic("Query type not supported!") } + if err != nil || db == nil { + log.Fatalln("Database init failed:", err) + } + dbCache[typ] = db return } -func GetIPDBbyName(name string) (db dbif.DB) { +func GetIPDBbyName(name string) (dbif.DB, error) { name = strings.ToLower(name) switch name { case "geo", "geoip", "geoip2": @@ -99,9 +105,18 @@ func GetIPDBbyName(name string) (db dbif.DB) { } func Update() { - qqwry.Download(QQWryPath) - zxipv6wry.Download(ZXIPv6WryPath) - cdn.Download(CDNPath) + _, err := qqwry.Download(QQWryPath) + if err != nil { + log.Fatalln("Database QQWry download failed:", err) + } + _, err = zxipv6wry.Download(ZXIPv6WryPath) + if err != nil { + log.Fatalln("Database ZXIPv6Wry download failed:", err) + } + _, err = cdn.Download(CDNPath) + if err != nil { + log.Fatalln("Database CDN download failed:", err) + } } func Find(typ dbif.QueryType, query string) string { diff --git a/pkg/cdn/cdn.go b/pkg/cdn/cdn.go index 6ff5419..99b1e3a 100644 --- a/pkg/cdn/cdn.go +++ b/pkg/cdn/cdn.go @@ -25,7 +25,7 @@ func (r CDNResult) String() string { return r.Name } -func NewCDN(filePath string) *CDN { +func NewCDN(filePath string) (*CDN, error) { cdnDist := make(CDNDist) cdnData := make([]byte, 0) @@ -34,26 +34,26 @@ func NewCDN(filePath string) *CDN { log.Println("文件不存在,尝试从网络获取最新CDN数据库") cdnData, err = Download(filePath) if err != nil { - os.Exit(1) + return nil, err } } else { cdnFile, err := os.OpenFile(filePath, os.O_RDONLY, 0400) if err != nil { - panic(err) + return nil, err } defer cdnFile.Close() cdnData, err = ioutil.ReadAll(cdnFile) if err != nil { - panic(err) + return nil, err } } err = json.Unmarshal(cdnData, &cdnDist) if err != nil { - panic("cdn data parse failed!") + return nil, err } - return &CDN{Data: cdnDist} + return &CDN{Data: cdnDist}, nil } func (db CDN) Find(query string, params ...string) (result fmt.Stringer, err error) { diff --git a/pkg/cdn/update.go b/pkg/cdn/update.go index d7c3488..2aa7e23 100644 --- a/pkg/cdn/update.go +++ b/pkg/cdn/update.go @@ -1,49 +1,28 @@ package cdn import ( - "io/ioutil" "log" - "net/http" "github.com/zu1k/nali/pkg/common" ) -func Download(filePath string) (data []byte, err error) { - data, err = getData() - if err != nil { - log.Printf("CDN数据库下载失败,请手动下载解压后保存到本地: %s \n", filePath) - log.Println("下载链接:", githubUrl) - return - } - - common.ExistThenRemove(filePath) - if err := ioutil.WriteFile(filePath, data, 0644); err == nil { - log.Printf("已将最新的 CDN数据库 保存到本地: %s \n", filePath) - } - return -} - const ( githubUrl = "https://raw.githubusercontent.com/SukkaLab/cdn/master/dist/cdn.json" jsdelivrUrl = "https://cdn.jsdelivr.net/gh/SukkaLab/cdn/dist/cdn.json" ) -func getData() (data []byte, err error) { - resp, err := http.Get(jsdelivrUrl) +func Download(filePath ...string) (data []byte, err error) { + data, err = common.GetHttpClient().Get(jsdelivrUrl, githubUrl) if err != nil { - return nil, err + log.Printf("CDN数据库下载失败,请手动下载解压后保存到本地: %s \n", filePath) + log.Println("下载链接:", githubUrl) + return } - if resp.StatusCode != 200 { - resp, err = http.Get(githubUrl) - if err != nil { - return nil, err + + if len(filePath) == 1 { + if err := common.SaveFile(filePath[0], data); err == nil { + log.Printf("已将最新的 CDN数据库 保存到本地: %s \n", filePath) } } - defer resp.Body.Close() - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, err - } - return body, nil + return } diff --git a/pkg/common/dbtool.go b/pkg/common/dbtool.go index 011e186..79d6b8e 100644 --- a/pkg/common/dbtool.go +++ b/pkg/common/dbtool.go @@ -1,24 +1,8 @@ package common -import ( - "log" - "os" -) - func ByteToUInt32(data []byte) uint32 { i := uint32(data[0]) & 0xff i |= (uint32(data[1]) << 8) & 0xff00 i |= (uint32(data[2]) << 16) & 0xff0000 return i } - -func ExistThenRemove(filePath string) { - _, err := os.Stat(filePath) - if err == nil { - err = os.Remove(filePath) - if err != nil { - log.Fatalln("旧文件删除失败", err.Error()) - os.Exit(1) - } - } -} diff --git a/pkg/common/httpclient.go b/pkg/common/httpclient.go new file mode 100644 index 0000000..9275114 --- /dev/null +++ b/pkg/common/httpclient.go @@ -0,0 +1,57 @@ +package common + +import ( + "io/ioutil" + "net/http" + "time" +) + +const UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.105 Safari/537.36" + +type HttpClient struct { + *http.Client +} + +var httpClient *HttpClient + +func init() { + httpClient = &HttpClient{http.DefaultClient} + httpClient.Timeout = time.Second * 30 + httpClient.Transport = &http.Transport{ + TLSHandshakeTimeout: time.Second * 5, + IdleConnTimeout: time.Second * 20, + ResponseHeaderTimeout: time.Second * 20, + ExpectContinueTimeout: time.Second * 20, + } +} + +func GetHttpClient() *HttpClient { + c := *httpClient + return &c +} + +func (c *HttpClient) Get(urls ...string) (body []byte, err error) { + var req *http.Request + var resp *http.Response + + for _, url := range urls { + req, err = http.NewRequest(http.MethodGet, url, nil) + if err != nil { + continue + } + req.Header.Set("Accept-Language", "zh-CN,zh;q=0.9,en;q=0.8") + req.Header.Set("User-Agent", UserAgent) + resp, err = c.Do(req) + + if err == nil && resp != nil && resp.StatusCode == 200 { + defer resp.Body.Close() + body, err = ioutil.ReadAll(resp.Body) + if err != nil { + continue + } + return + } + } + + return nil, err +} diff --git a/pkg/common/savefile.go b/pkg/common/savefile.go new file mode 100644 index 0000000..0f3b9bc --- /dev/null +++ b/pkg/common/savefile.go @@ -0,0 +1,21 @@ +package common + +import ( + "io/ioutil" + "log" + "os" +) + +func SaveFile(path string, data []byte) (err error) { + // Remove file if exist + _, err = os.Stat(path) + if err == nil { + err = os.Remove(path) + if err != nil { + log.Fatalln("旧文件删除失败", err.Error()) + } + } + + // save file + return ioutil.WriteFile(path, data, 0644) +} diff --git a/pkg/dbif/db.go b/pkg/dbif/db.go index c5fc4a6..7239ff0 100644 --- a/pkg/dbif/db.go +++ b/pkg/dbif/db.go @@ -24,10 +24,10 @@ type DB interface { } var ( - _ DB = qqwry.QQwry{} - _ DB = zxipv6wry.ZXwry{} - _ DB = ipip.IPIPFree{} - _ DB = geoip.GeoIP{} - _ DB = ip2region.Ip2Region{} - _ DB = cdn.CDN{} + _ DB = &qqwry.QQwry{} + _ DB = &zxipv6wry.ZXwry{} + _ DB = &ipip.IPIPFree{} + _ DB = &geoip.GeoIP{} + _ DB = &ip2region.Ip2Region{} + _ DB = &cdn.CDN{} ) diff --git a/pkg/geoip/geoip.go b/pkg/geoip/geoip.go index 1f75b05..c7ea538 100644 --- a/pkg/geoip/geoip.go +++ b/pkg/geoip/geoip.go @@ -16,20 +16,19 @@ type GeoIP struct { } // new geoip from database file -func NewGeoIP(filePath string) (geoip GeoIP) { +func NewGeoIP(filePath string) (*GeoIP, error) { // 判断文件是否存在 _, err := os.Stat(filePath) if err != nil && os.IsNotExist(err) { log.Println("文件不存在,请自行下载 Geoip2 City库,并保存在", filePath) - os.Exit(1) + return nil, err } else { db, err := geoip2.Open(filePath) if err != nil { log.Fatal(err) } - geoip = GeoIP{db: db} + return &GeoIP{db: db}, nil } - return } func (g GeoIP) Find(query string, params ...string) (result fmt.Stringer, err error) { diff --git a/pkg/ip2region/ip2region.go b/pkg/ip2region/ip2region.go index fb46c31..adc7ffd 100644 --- a/pkg/ip2region/ip2region.go +++ b/pkg/ip2region/ip2region.go @@ -7,7 +7,6 @@ import ( "strings" "github.com/lionsoul2014/ip2region/binding/golang/ip2region" - "github.com/zu1k/nali/pkg/common" ) @@ -15,24 +14,24 @@ type Ip2Region struct { db *ip2region.Ip2Region } -func NewIp2Region(filePath string) Ip2Region { +func NewIp2Region(filePath string) (*Ip2Region, error) { _, err := os.Stat(filePath) if err != nil && os.IsNotExist(err) { log.Println("文件不存在,尝试从网络获取最新 ip2region 库") _, err = Download(filePath) if err != nil { - os.Exit(1) + return nil, err } } region, err := ip2region.New(filePath) if err != nil { - panic(err) + return nil, err } - return Ip2Region{ + return &Ip2Region{ db: region, - } + }, nil } func (db Ip2Region) Find(query string, params ...string) (result fmt.Stringer, err error) { diff --git a/pkg/ip2region/update.go b/pkg/ip2region/update.go index 1d13ec9..af7fe38 100644 --- a/pkg/ip2region/update.go +++ b/pkg/ip2region/update.go @@ -1,49 +1,28 @@ package ip2region import ( - "io/ioutil" "log" - "net/http" "github.com/zu1k/nali/pkg/common" ) -func Download(filePath string) (data []byte, err error) { - data, err = getData() - if err != nil { - log.Printf("CDN数据库下载失败,请手动下载解压后保存到本地: %s \n", filePath) - log.Println("下载链接:", githubUrl) - return - } - - common.ExistThenRemove(filePath) - if err := ioutil.WriteFile(filePath, data, 0644); err == nil { - log.Printf("已将最新的 ip2region 保存到本地: %s \n", filePath) - } - return -} - const ( githubUrl = "https://raw.githubusercontent.com/lionsoul2014/ip2region/master/data/ip2region.db" jsdelivrUrl = "https://cdn.jsdelivr.net/gh/lionsoul2014/ip2region/data/ip2region.db" ) -func getData() (data []byte, err error) { - resp, err := http.Get(jsdelivrUrl) +func Download(filePath ...string) (data []byte, err error) { + data, err = common.GetHttpClient().Get(jsdelivrUrl, githubUrl) if err != nil { - return nil, err + log.Printf("CDN数据库下载失败,请手动下载解压后保存到本地: %s \n", filePath) + log.Println("下载链接:", githubUrl) + return } - if resp.StatusCode != 200 { - resp, err = http.Get(githubUrl) - if err != nil { - return nil, err + + if len(filePath) == 1 { + if err := common.SaveFile(filePath[0], data); err == nil { + log.Println("已将最新的 ip2region 保存到本地:", filePath[0]) } } - defer resp.Body.Close() - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, err - } - return body, nil + return } diff --git a/pkg/ipip/ipipfree.go b/pkg/ipip/ipipfree.go index 0d2e4ef..60c8b77 100644 --- a/pkg/ipip/ipipfree.go +++ b/pkg/ipip/ipipfree.go @@ -12,21 +12,18 @@ type IPIPFree struct { *ipdb.City } -func NewIPIPFree(filePath string) IPIPFree { +func NewIPIPFree(filePath string) (*IPIPFree, error) { _, err := os.Stat(filePath) if err != nil && os.IsNotExist(err) { log.Printf("IPIP数据库不存在,请手动下载解压后保存到本地: %s \n", filePath) log.Println("下载链接: https://www.ipip.net/product/ip.html") - os.Exit(1) - return IPIPFree{} + return nil, err } else { db, err := ipdb.NewCity(filePath) if err != nil { - log.Fatalln("IPIP 数据库 初始化失败") - log.Fatal(err) - os.Exit(1) + return nil, err } - return IPIPFree{City: db} + return &IPIPFree{City: db}, nil } } diff --git a/pkg/qqwry/qqwry.go b/pkg/qqwry/qqwry.go index c24c1fc..fb0a9db 100644 --- a/pkg/qqwry/qqwry.go +++ b/pkg/qqwry/qqwry.go @@ -19,7 +19,7 @@ type QQwry struct { } // NewQQwry new database from path -func NewQQwry(filePath string) QQwry { +func NewQQwry(filePath string) (*QQwry, error) { var fileData []byte var fileInfo common.FileData @@ -28,18 +28,18 @@ func NewQQwry(filePath string) QQwry { log.Println("文件不存在,尝试从网络获取最新纯真 IP 库") fileData, err = Download(filePath) if err != nil { - os.Exit(1) + return nil, err } } else { fileInfo.FileBase, err = os.OpenFile(filePath, os.O_RDONLY, 0400) if err != nil { - panic(err) + return nil, err } defer fileInfo.FileBase.Close() fileData, err = ioutil.ReadAll(fileInfo.FileBase) if err != nil { - panic(err) + return nil, err } } fileInfo.Data = fileData @@ -48,12 +48,12 @@ func NewQQwry(filePath string) QQwry { start := binary.LittleEndian.Uint32(buf[:4]) end := binary.LittleEndian.Uint32(buf[4:]) - return QQwry{ + return &QQwry{ IPDB: common.IPDB{ Data: &fileInfo, IPNum: (end-start)/7 + 1, }, - } + }, nil } func (db QQwry) Find(query string, params ...string) (result fmt.Stringer, err error) { diff --git a/pkg/qqwry/update.go b/pkg/qqwry/update.go index 9905f84..41d51b4 100644 --- a/pkg/qqwry/update.go +++ b/pkg/qqwry/update.go @@ -6,42 +6,43 @@ import ( "encoding/binary" "io/ioutil" "log" - "net/http" "github.com/zu1k/nali/pkg/common" ) -func Download(filePath string) (data []byte, err error) { - data, err = getData() +func Download(filePath ...string) (data []byte, err error) { + data, err = downloadAndDecrypt() if err != nil { log.Printf("纯真IP库下载失败,请手动下载解压后保存到本地: %s \n", filePath) log.Println("下载链接: https://qqwry.mirror.noc.one/qqwry.rar") return } - common.ExistThenRemove(filePath) - if err = ioutil.WriteFile(filePath, data, 0644); err == nil { - log.Printf("已将最新的 纯真IP库 保存到本地: %s ", filePath) + + if len(filePath) == 1 { + if err := common.SaveFile(filePath[0], data); err == nil { + log.Println("已将最新的 纯真IP库 保存到本地:", filePath) + } } return } -func getData() (data []byte, err error) { - resp, err := http.Get("https://qqwry.mirror.noc.one/qqwry.rar") - if err != nil { - return - } - defer resp.Body.Close() +const ( + mirror = "https://qqwry.mirror.noc.one/qqwry.rar" + key = "https://qqwry.mirror.noc.one/copywrite.rar" +) - body, err := ioutil.ReadAll(resp.Body) +func downloadAndDecrypt() (data []byte, err error) { + data, err = common.GetHttpClient().Get(mirror) if err != nil { - return + return nil, err } + key, err := getCopyWriteKey() if err != nil { - return + return nil, err } - return unRar(body, key) + return unRar(data, key) } func unRar(data []byte, key uint32) ([]byte, error) { @@ -62,15 +63,10 @@ func unRar(data []byte, key uint32) ([]byte, error) { } func getCopyWriteKey() (uint32, error) { - resp, err := http.Get("https://qqwry.mirror.noc.one/copywrite.rar") + body, err := common.GetHttpClient().Get(key) if err != nil { return 0, err } - defer resp.Body.Close() - if body, err := ioutil.ReadAll(resp.Body); err != nil { - return 0, err - } else { - return binary.LittleEndian.Uint32(body[5*4:]), nil - } + return binary.LittleEndian.Uint32(body[5*4:]), nil } diff --git a/pkg/zxipv6wry/update.go b/pkg/zxipv6wry/update.go index f8af2ac..44c75cf 100644 --- a/pkg/zxipv6wry/update.go +++ b/pkg/zxipv6wry/update.go @@ -4,47 +4,41 @@ import ( "io" "io/ioutil" "log" - "net/http" "os" "github.com/saracen/go7z" "github.com/zu1k/nali/pkg/common" ) -func Download(filePath string) (data []byte, err error) { +func Download(filePath ...string) (data []byte, err error) { data, err = getData() if err != nil { log.Printf("ZX IPv6数据库下载失败,请手动下载解压后保存到本地: %s \n", filePath) log.Println("下载链接: https://ip.zxinc.org/ip.7z") return } - common.ExistThenRemove(filePath) - if err = ioutil.WriteFile(filePath, data, 0644); err == nil { - log.Printf("已将最新的 ZX IPv6数据库 保存到本地: %s ", filePath) + + if len(filePath) == 1 { + if err := common.SaveFile(filePath[0], data); err == nil { + log.Println("已将最新的 ZX IPv6数据库 保存到本地:", filePath[0]) + } } return - } -func getData() (data []byte, err error) { - resp, err := http.Get("https://ip.zxinc.org/ip.7z") - if err != nil { - return nil, err - } - defer resp.Body.Close() +const ( + zx = "https://ip.zxinc.org/ip.7z" +) - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, err - } +func getData() (data []byte, err error) { + data, err = common.GetHttpClient().Get(zx) file7z, err := ioutil.TempFile("", "*") if err != nil { - panic(err) + return nil, err } defer os.Remove(file7z.Name()) - - if err := ioutil.WriteFile(file7z.Name(), body, 0644); err == nil { + if err := common.SaveFile(file7z.Name(), data); err == nil { return Un7z(file7z.Name()) } return @@ -53,21 +47,21 @@ func getData() (data []byte, err error) { func Un7z(filePath string) (data []byte, err error) { sz, err := go7z.OpenReader(filePath) if err != nil { - panic(err) + return nil, err } defer sz.Close() fileNoNeed, err := ioutil.TempFile("", "*") if err != nil { - panic(err) + return nil, err } fileNeed, err := ioutil.TempFile("", "*") if err != nil { - panic(err) + return nil, err } if err != nil { - panic(err) + return nil, err } for { hdr, err := sz.Next() @@ -75,7 +69,7 @@ func Un7z(filePath string) (data []byte, err error) { break // End of archive } if err != nil { - panic(err) + return nil, err } if hdr.Name == "ipv6wry.db" { @@ -90,7 +84,7 @@ func Un7z(filePath string) (data []byte, err error) { } err = fileNoNeed.Close() if err != nil { - panic(err) + return nil, err } defer os.Remove(fileNoNeed.Name()) defer os.Remove(fileNeed.Name()) diff --git a/pkg/zxipv6wry/zxipv6wry.go b/pkg/zxipv6wry/zxipv6wry.go index 4db88ec..13ae7b8 100644 --- a/pkg/zxipv6wry/zxipv6wry.go +++ b/pkg/zxipv6wry/zxipv6wry.go @@ -18,7 +18,7 @@ type ZXwry struct { common.IPDB } -func NewZXwry(filePath string) ZXwry { +func NewZXwry(filePath string) (*ZXwry, error) { var fileData []byte var fileInfo common.FileData @@ -27,28 +27,28 @@ func NewZXwry(filePath string) ZXwry { log.Println("文件不存在,尝试从网络获取最新ZX IPv6数据库") fileData, err = Download(filePath) if err != nil { - os.Exit(1) + return nil, err } } else { fileInfo.FileBase, err = os.OpenFile(filePath, os.O_RDONLY, 0400) if err != nil { - panic(err) + return nil, err } defer fileInfo.FileBase.Close() fileData, err = ioutil.ReadAll(fileInfo.FileBase) if err != nil { - panic(err) + return nil, err } } fileInfo.Data = fileData - return ZXwry{ + return &ZXwry{ IPDB: common.IPDB{ Data: &fileInfo, }, - } + }, nil } func (db ZXwry) Find(query string, params ...string) (result fmt.Stringer, err error) {