From 69d5fd5eea174d22ebea439543238c22b05be412 Mon Sep 17 00:00:00 2001
From: zu1k <i@lgf.im>
Date: Fri, 17 Jul 2020 11:37:03 +0800
Subject: [PATCH] add db choose

---
 internal/app/parse.go | 16 ++++------------
 main.go               | 18 +++++++++++++++++-
 pkg/geoip/geoip.go    | 19 ++++++++++++++-----
 pkg/geoip/update.go   |  1 +
 4 files changed, 36 insertions(+), 18 deletions(-)
 create mode 100644 pkg/geoip/update.go

diff --git a/internal/app/parse.go b/internal/app/parse.go
index 54606aa..2376538 100644
--- a/internal/app/parse.go
+++ b/internal/app/parse.go
@@ -22,20 +22,12 @@ var (
 )
 
 // init ip db content
-func InitIPDB() {
-	qqip = qqwry.NewQQwry(filepath.Join(constant.HomePath, "qqwry.dat"))
-	//geoip = geoip2.NewGeoIP(filepath.Join(constant.HomePath, "GeoLite2-City.mmdb"))
-
-	db = qqip
-}
-
-// set db to use
-func SetDB(dbName ipdb.IPDBType) {
-	switch dbName {
+func InitIPDB(ipdbtype ipdb.IPDBType) {
+	switch ipdbtype {
 	case ipdb.GEOIP2:
-		db = geoip
+		db = geoip2.NewGeoIP(filepath.Join(constant.HomePath, "GeoLite2-City.mmdb"))
 	case ipdb.QQIP:
-		db = qqip
+		db = qqwry.NewQQwry(filepath.Join(constant.HomePath, "qqwry.dat"))
 	}
 }
 
diff --git a/main.go b/main.go
index d692626..fcb172c 100644
--- a/main.go
+++ b/main.go
@@ -4,6 +4,9 @@ import (
 	"log"
 	"os"
 	"path/filepath"
+	"strings"
+
+	"github.com/zu1k/nali/internal/ipdb"
 
 	"github.com/zu1k/nali/internal/app"
 
@@ -13,7 +16,7 @@ import (
 
 func main() {
 	setHomePath()
-	app.InitIPDB()
+	app.InitIPDB(getIPDBType())
 	cmd.Execute()
 }
 
@@ -30,3 +33,16 @@ func setHomePath() {
 		}
 	}
 }
+
+func getIPDBType() ipdb.IPDBType {
+	dbname := os.Getenv("NALI_DB")
+	dbname = strings.ToLower(dbname)
+	switch dbname {
+	case "geo", "geoip", "geoip2":
+		return ipdb.GEOIP2
+	case "chunzhen", "qqip", "qqwry":
+		return ipdb.QQIP
+	default:
+		return ipdb.QQIP
+	}
+}
diff --git a/pkg/geoip/geoip.go b/pkg/geoip/geoip.go
index 069095a..792fc8b 100644
--- a/pkg/geoip/geoip.go
+++ b/pkg/geoip/geoip.go
@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"log"
 	"net"
+	"os"
 
 	"github.com/oschwald/geoip2-golang"
 )
@@ -14,12 +15,20 @@ type GeoIP struct {
 }
 
 // new geoip from db file
-func NewGeoIP(filePath string) GeoIP {
-	db, err := geoip2.Open(filePath)
-	if err != nil {
-		log.Fatal(err)
+func NewGeoIP(filePath string) (geoip GeoIP) {
+	// 判断文件是否存在
+	_, err := os.Stat(filePath)
+	if err != nil && os.IsNotExist(err) {
+		log.Println("文件不存在,请自行下载 Geoip2 City库,并保存在", filePath)
+		os.Exit(1)
+	} else {
+		db, err := geoip2.Open(filePath)
+		if err != nil {
+			log.Fatal(err)
+		}
+		geoip = GeoIP{db: db}
 	}
-	return GeoIP{db: db}
+	return
 }
 
 // find ip info
diff --git a/pkg/geoip/update.go b/pkg/geoip/update.go
new file mode 100644
index 0000000..c80faa3
--- /dev/null
+++ b/pkg/geoip/update.go
@@ -0,0 +1 @@
+package geoip