diff --git a/codes_v2.go b/codes_v2.go index c2198c8..309f4e8 100644 --- a/codes_v2.go +++ b/codes_v2.go @@ -8,9 +8,11 @@ import ( "github.com/injoyai/ios" "github.com/injoyai/ios/client" "github.com/injoyai/logs" + "github.com/injoyai/tdx/internal/gbbq" "github.com/injoyai/tdx/internal/xorms" "github.com/injoyai/tdx/protocol" "github.com/robfig/cron/v3" + "os" "path/filepath" "time" "xorm.io/xorm" @@ -18,9 +20,15 @@ import ( type Codes2Option func(*Codes2) -func WithFilename(filename string) Codes2Option { +func WithDBFilename(filename string) Codes2Option { return func(c *Codes2) { - c.filename = filename + c.dbFilename = filename + } +} + +func WithTempDir(dir string) Codes2Option { + return func(c *Codes2) { + c.tempDir = dir } } @@ -30,9 +38,9 @@ func WithSpec(spec string) Codes2Option { } } -func WithKey(key string) Codes2Option { +func WithUpdateKey(key string) Codes2Option { return func(c *Codes2) { - c.key = key + c.updateKey = key } } @@ -42,20 +50,37 @@ func WithRetry(retry int) Codes2Option { } } +func WithDial(dial ios.DialFunc, op ...client.Option) Codes2Option { + return func(c *Codes2) { + c.dial = dial + c.dialOption = op + } +} + +func WithDialOption(op ...client.Option) Codes2Option { + return func(c *Codes2) { + c.dialOption = op + } +} + func NewCodes2(op ...Codes2Option) (*Codes2, error) { cs := &Codes2{ - filename: filepath.Join(DefaultDatabaseDir, "codes.db"), + dbFilename: filepath.Join(DefaultDatabaseDir, "codes2.db"), + tempDir: filepath.Join(DefaultDataDir, "temp"), spec: "10 0 9 * * *", - key: "codes", + updateKey: "codes", retry: 3, dial: NewRangeDial(Hosts), dialOption: nil, + m: maps.NewGeneric[string, *CodeModel](), } for _, o := range op { o(cs) } + os.MkdirAll(cs.tempDir, 0777) + var err error // 初始化连接 @@ -65,7 +90,7 @@ func NewCodes2(op ...Codes2Option) (*Codes2, error) { } // 初始化数据库 - cs.db, err = xorms.NewSqlite(cs.filename) + cs.db, err = xorms.NewSqlite(cs.dbFilename) if err != nil { return nil, err } @@ -103,9 +128,10 @@ func NewCodes2(op ...Codes2Option) (*Codes2, error) { var _ ICodes = &Codes2{} type Codes2 struct { - filename string //数据库文件 + dbFilename string //数据库文件 + tempDir string //临时目录 spec string //定时规则 - key string //标识 + updateKey string //标识 retry int //重试次数 dial ios.DialFunc //连接 dialOption []client.Option // @@ -147,11 +173,11 @@ func (this *Codes2) GetETFs(limit ...int) []string { func (this *Codes2) updated() (bool, error) { update := new(UpdateModel) { //查询或者插入一条数据 - has, err := this.db.Where("`Key`=?", this.key).Get(update) + has, err := this.db.Where("`Key`=?", this.updateKey).Get(update) if err != nil { return true, err } else if !has { - update.Key = this.key + update.Key = this.updateKey if _, err = this.db.Insert(update); err != nil { return true, err } @@ -188,6 +214,7 @@ func (this *Codes2) Update() error { etfs := []string(nil) for _, v := range codes { fullCode := v.FullCode() + this.m.Set(fullCode, v) switch { case protocol.IsStock(fullCode): stocks = append(stocks, fullCode) @@ -255,7 +282,32 @@ func (this *Codes2) update() ([]*CodeModel, error) { } } - //4. 插入或者更新数据库 + //4. 获取gbbq + ss, err := gbbq.DownloadAndDecode(this.tempDir) + if err != nil { + logs.Err(err) + return nil, err + } + + mStock := map[string]gbbq.Stock{} + for _, v := range ss { + mStock[protocol.AddPrefix(v.Code)] = v + } + + //5. 赋值流通股和总股本 + for _, v := range insert { + if protocol.IsStock(v.FullCode()) { + v.FloatStock, v.TotalStock = ss.GetStock(v.Code) + } + } + for _, v := range update { + if stock, ok := mStock[v.FullCode()]; ok { + v.FloatStock = stock.Float + v.TotalStock = stock.Total + } + } + + //6. 插入或者更新数据库 err = this.db.SessionFunc(func(session *xorm.Session) error { for _, v := range mCode { if _, err = session.Where("Exchange=? and Code=? ", v.Exchange, v.Code).Delete(v); err != nil { @@ -279,6 +331,6 @@ func (this *Codes2) update() ([]*CodeModel, error) { } //更新时间 - _, err = this.db.Where("`Key`=?", this.key).Update(&UpdateModel{Time: time.Now().Unix()}) + _, err = this.db.Where("`Key`=?", this.updateKey).Update(&UpdateModel{Time: time.Now().Unix()}) return list, err }