diff --git a/internal/app/parse.go b/internal/app/parse.go index 54606aa..2376538 100644 --- a/internal/app/parse.go +++ b/internal/app/parse.go @@ -22,20 +22,12 @@ var ( ) // init ip db content -func InitIPDB() { - qqip = qqwry.NewQQwry(filepath.Join(constant.HomePath, "qqwry.dat")) - //geoip = geoip2.NewGeoIP(filepath.Join(constant.HomePath, "GeoLite2-City.mmdb")) - - db = qqip -} - -// set db to use -func SetDB(dbName ipdb.IPDBType) { - switch dbName { +func InitIPDB(ipdbtype ipdb.IPDBType) { + switch ipdbtype { case ipdb.GEOIP2: - db = geoip + db = geoip2.NewGeoIP(filepath.Join(constant.HomePath, "GeoLite2-City.mmdb")) case ipdb.QQIP: - db = qqip + db = qqwry.NewQQwry(filepath.Join(constant.HomePath, "qqwry.dat")) } } diff --git a/main.go b/main.go index d692626..fcb172c 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,9 @@ import ( "log" "os" "path/filepath" + "strings" + + "github.com/zu1k/nali/internal/ipdb" "github.com/zu1k/nali/internal/app" @@ -13,7 +16,7 @@ import ( func main() { setHomePath() - app.InitIPDB() + app.InitIPDB(getIPDBType()) cmd.Execute() } @@ -30,3 +33,16 @@ func setHomePath() { } } } + +func getIPDBType() ipdb.IPDBType { + dbname := os.Getenv("NALI_DB") + dbname = strings.ToLower(dbname) + switch dbname { + case "geo", "geoip", "geoip2": + return ipdb.GEOIP2 + case "chunzhen", "qqip", "qqwry": + return ipdb.QQIP + default: + return ipdb.QQIP + } +} diff --git a/pkg/geoip/geoip.go b/pkg/geoip/geoip.go index 069095a..792fc8b 100644 --- a/pkg/geoip/geoip.go +++ b/pkg/geoip/geoip.go @@ -4,6 +4,7 @@ import ( "fmt" "log" "net" + "os" "github.com/oschwald/geoip2-golang" ) @@ -14,12 +15,20 @@ type GeoIP struct { } // new geoip from db file -func NewGeoIP(filePath string) GeoIP { - db, err := geoip2.Open(filePath) - if err != nil { - log.Fatal(err) +func NewGeoIP(filePath string) (geoip GeoIP) { + // 判断文件是否存在 + _, err := os.Stat(filePath) + if err != nil && os.IsNotExist(err) { + log.Println("文件不存在,请自行下载 Geoip2 City库,并保存在", filePath) + os.Exit(1) + } else { + db, err := geoip2.Open(filePath) + if err != nil { + log.Fatal(err) + } + geoip = GeoIP{db: db} } - return GeoIP{db: db} + return } // find ip info diff --git a/pkg/geoip/update.go b/pkg/geoip/update.go new file mode 100644 index 0000000..c80faa3 --- /dev/null +++ b/pkg/geoip/update.go @@ -0,0 +1 @@ +package geoip