// Files: injoyai-tdx/codes_v2.go
// 2025-11-13 16:57:53 +08:00
//
// 285 lines
// 6.0 KiB
// Go

package tdx
import (
"errors"
"github.com/injoyai/base/maps"
"github.com/injoyai/base/types"
"github.com/injoyai/conv"
"github.com/injoyai/ios"
"github.com/injoyai/ios/client"
"github.com/injoyai/logs"
"github.com/injoyai/tdx/internal/xorms"
"github.com/injoyai/tdx/protocol"
"github.com/robfig/cron/v3"
"path/filepath"
"time"
"xorm.io/xorm"
)
// Codes2Option customizes a Codes2 before it is initialized; NewCodes2
// applies the given options in order on top of its defaults.
type Codes2Option func(*Codes2)
// WithFilename returns an option that sets the path of the SQLite database
// file opened by NewCodes2.
func WithFilename(filename string) Codes2Option {
	return func(cs *Codes2) { cs.filename = filename }
}
// WithSpec returns an option that sets the cron expression used for the
// scheduled refresh. NewCodes2 builds its scheduler with cron.WithSeconds,
// so the expression carries a leading seconds field.
func WithSpec(spec string) Codes2Option {
	return func(cs *Codes2) { cs.spec = spec }
}
// WithKey returns an option that sets the key identifying this instance's
// row in the UpdateModel table.
func WithKey(key string) Codes2Option {
	return func(cs *Codes2) { cs.key = key }
}
// WithRetry returns an option that sets the retry count stored on the
// Codes2 instance.
func WithRetry(retry int) Codes2Option {
	return func(cs *Codes2) { cs.retry = retry }
}
// NewCodes2 builds a Codes2, loads the code table once and starts the
// scheduled refresh.
//
// Defaults are applied first (database file "codes.db" under
// DefaultDatabaseDir, cron spec "10 0 9 * * *", key "codes", 3 attempts,
// a range dialer over Hosts) and then overridden by op in the order given;
// nil options are skipped. The constructor dials the server, opens the
// SQLite database, syncs the CodeModel and UpdateModel tables, performs one
// synchronous Update and finally registers and starts the cron job.
//
// The first error encountered is returned and no Codes2 is produced.
func NewCodes2(op ...Codes2Option) (*Codes2, error) {
	cs := &Codes2{
		filename:   filepath.Join(DefaultDatabaseDir, "codes.db"),
		spec:       "10 0 9 * * *",
		key:        "codes",
		retry:      3,
		dial:       NewRangeDial(Hosts),
		dialOption: nil,
	}
	for _, o := range op {
		if o != nil {
			o(cs)
		}
	}

	var err error

	// Open the server connection.
	// NOTE(review): the connection (and the database opened below) is not
	// released on the later error paths — confirm whether Client and
	// xorms.Engine expose a Close that should be called there.
	cs.c, err = DialWith(cs.dial, cs.dialOption...)
	if err != nil {
		return nil, err
	}

	// Open the database and make sure both tables exist.
	cs.db, err = xorms.NewSqlite(cs.filename)
	if err != nil {
		return nil, err
	}
	if err = cs.db.Sync2(new(CodeModel), new(UpdateModel)); err != nil {
		return nil, err
	}

	// Refresh immediately so the caches are filled before returning.
	if err = cs.Update(); err != nil {
		return nil, err
	}

	// Scheduled refresh: honor the configured retry count, but always make
	// at least one attempt so a non-positive value cannot disable the job.
	attempts := cs.retry
	if attempts < 1 {
		attempts = 1
	}
	cr := cron.New(cron.WithSeconds())
	_, err = cr.AddFunc(cs.spec, func() {
		for i := 0; i < attempts; i++ {
			err := cs.Update()
			if err == nil {
				return
			}
			logs.Err(err)
			// Back off before the next attempt; waiting after the final
			// one would only delay the end of the job.
			if i < attempts-1 {
				time.Sleep(time.Minute * 5)
			}
		}
	})
	if err != nil {
		return nil, err
	}
	cr.Start()

	return cs, nil
}
// Compile-time check that *Codes2 satisfies ICodes.
var _ ICodes = (*Codes2)(nil)
// Codes2 keeps the security code table in a SQLite database and caches the
// stock and ETF code lists in memory. The first group of fields is
// configuration set through Codes2Option values; the second group is
// internal state.
type Codes2 struct {
filename string //path of the SQLite database file
spec string //cron expression of the scheduled refresh
key string //key of this instance's row in the UpdateModel table
retry int //retry count
dial ios.DialFunc //dial function passed to DialWith
dialOption []client.Option //client options passed to DialWith
/*
Internal fields
*/
c *Client //server connection returned by DialWith
db *xorms.Engine //database engine returned by xorms.NewSqlite
stocks types.List[string] //cached stock codes, rebuilt by Update
etfs types.List[string] //cached ETF codes, rebuilt by Update
m *maps.Generic[string, *CodeModel] //code cache read by Get and GetName; NOTE(review): no assignment to it is visible in this file
}
// Get returns the cached CodeModel stored under code, or nil when the code
// is not cached.
func (this *Codes2) Get(code string) *CodeModel {
	// The cache pointer may be unset; calling through a nil pointer would
	// panic, so treat that the same as a cache miss.
	if this.m == nil {
		return nil
	}
	v, _ := this.m.Get(code)
	return v
}
// GetName returns the name of the cached CodeModel stored under code, or
// "未知" when the code is not cached.
func (this *Codes2) GetName(code string) string {
	// The cache pointer may be unset; calling through a nil pointer would
	// panic, so treat that the same as a cache miss.
	if this.m == nil {
		return "未知"
	}
	v, _ := this.m.Get(code)
	if v == nil {
		return "未知"
	}
	return v.Name
}
// GetStocks returns the cached stock codes. The optional limit is resolved
// through conv.Default, with the full cached length as the fallback when no
// limit is given.
func (this *Codes2) GetStocks(limit ...int) []string {
	return this.stocks.Limit(conv.Default(this.stocks.Len(), limit...))
}
// GetETFs returns the cached ETF codes. The optional limit is resolved
// through conv.Default, with the full cached length as the fallback when no
// limit is given.
func (this *Codes2) GetETFs(limit ...int) []string {
	return this.etfs.Limit(conv.Default(this.etfs.Len(), limit...))
}
// updated reports whether the code table has already been refreshed since
// the most recent 09:00 (local time) checkpoint.
//
// It loads the UpdateModel row whose Key equals this.key and inserts a new
// row when none exists. It returns false when a refresh is still required
// (including right after inserting the row) and true when the stored
// timestamp is at or after the checkpoint. If the query or the insert fails,
// the error is returned together with true.
func (this *Codes2) updated() (bool, error) {
	record := new(UpdateModel)

	// Load the bookkeeping row, creating it on first use.
	has, err := this.db.Where("`Key`=?", this.key).Get(record)
	if err != nil {
		return true, err
	}
	if !has {
		record.Key = this.key
		if _, err = this.db.Insert(record); err != nil {
			return true, err
		}
		return false, nil
	}

	// The checkpoint is today's 09:00 once that moment has passed, otherwise
	// yesterday's 09:00. AddDate keeps the wall-clock hour across DST shifts.
	now := time.Now()
	node := time.Date(now.Year(), now.Month(), now.Day(), 9, 0, 0, 0, time.Local)
	if !now.After(node) {
		node = node.AddDate(0, 0, -1)
	}

	// Data is current only if the last refresh happened at or after the
	// checkpoint.
	lastUpdate := time.Unix(record.Time, 0)
	return !lastUpdate.Before(node), nil
}
// Update refreshes the code table through update and rebuilds the cached
// stock and ETF code lists from the result. Codes that are neither stocks
// nor ETFs are left out of both lists. On error the caches are left
// untouched and the error is returned.
func (this *Codes2) Update() error {
	codes, err := this.update()
	if err != nil {
		return err
	}

	var stocks, etfs []string
	for _, item := range codes {
		full := item.FullCode()
		if protocol.IsStock(full) {
			stocks = append(stocks, full)
		} else if protocol.IsETF(full) {
			etfs = append(etfs, full)
		}
	}

	this.stocks, this.etfs = stocks, etfs
	return nil
}
// update synchronizes the stored code table with the server when a refresh
// is due and returns the resulting codes.
//
// Steps:
//  1. Load every CodeModel from the database.
//  2. If updated reports the data is current, return the stored list as is.
//     An error from updated is not fatal; the refresh is attempted anyway.
//  3. Fetch all codes of the SH, SZ and BJ exchanges from the server and
//     diff them against the stored ones.
//  4. In one session, delete codes the server no longer lists, insert new
//     ones and update Name/LastPrice of renamed ones.
//  5. Record the refresh time on the UpdateModel row of this.key.
//
// The returned list mirrors the database after the refresh: stored entries
// first (renamed ones carrying the new Name and LastPrice, deleted ones
// removed), followed by newly inserted entries. On any error the list is nil.
func (this *Codes2) update() ([]*CodeModel, error) {
	if this.c == nil {
		return nil, errors.New("client is nil")
	}

	// Load everything currently stored.
	list := []*CodeModel(nil)
	if err := this.db.Find(&list); err != nil {
		return nil, err
	}

	// Nothing to do when the data is already current.
	updated, err := this.updated()
	if err == nil && updated {
		return list, nil
	}

	// Index stored codes by full code. Entries matched against the server
	// response are removed, so whatever remains afterwards is delisted.
	mCode := make(map[string]*CodeModel, len(list))
	for _, v := range list {
		mCode[v.FullCode()] = v
	}

	// Fetch all codes from the server and work out the differences.
	insert := []*CodeModel(nil)
	update := []*CodeModel(nil)
	for _, exchange := range []protocol.Exchange{protocol.ExchangeSH, protocol.ExchangeSZ, protocol.ExchangeBJ} {
		resp, err := this.c.GetCodeAll(exchange)
		if err != nil {
			return nil, err
		}
		prefix := exchange.String()
		for _, v := range resp.List {
			code := &CodeModel{
				Name:      v.Name,
				Code:      v.Code,
				Exchange:  prefix,
				Multiple:  v.Multiple,
				Decimal:   v.Decimal,
				LastPrice: v.LastPrice,
			}
			key := prefix + v.Code
			val, ok := mCode[key]
			if !ok {
				insert = append(insert, code)
				list = append(list, code)
				continue
			}
			delete(mCode, key)
			if val.Name != v.Name {
				update = append(update, code)
				// Keep the returned entry in step with the columns that
				// are written to the database below.
				val.Name = v.Name
				val.LastPrice = v.LastPrice
			}
		}
	}

	// Apply deletions, insertions and updates in a single session.
	err = this.db.SessionFunc(func(session *xorm.Session) error {
		for _, v := range mCode {
			if _, err := session.Where("Exchange=? and Code=? ", v.Exchange, v.Code).Delete(v); err != nil {
				return err
			}
		}
		for _, v := range insert {
			if _, err := session.Insert(v); err != nil {
				return err
			}
		}
		for _, v := range update {
			if _, err := session.Where("Exchange=? and Code=? ", v.Exchange, v.Code).Cols("Name,LastPrice").Update(v); err != nil {
				return err
			}
		}
		return nil
	})
	if err != nil {
		return nil, err
	}

	// Drop the entries that were just deleted from the database so the
	// returned list does not keep serving delisted codes. Pointer identity
	// ensures only the originally stored entries are removed.
	if len(mCode) > 0 {
		kept := list[:0]
		for _, v := range list {
			if gone, ok := mCode[v.FullCode()]; ok && gone == v {
				continue
			}
			kept = append(kept, v)
		}
		list = kept
	}

	// Record the refresh time.
	if _, err = this.db.Where("`Key`=?", this.key).Update(&UpdateModel{Time: time.Now().Unix()}); err != nil {
		return nil, err
	}
	return list, nil
}