Files
injoyai-tdx/codes_v2.go
2025-11-21 16:08:46 +08:00

381 lines
8.2 KiB
Go

package tdx
import (
"errors"
"iter"
"os"
"path/filepath"
"time"
"github.com/injoyai/base/maps"
"github.com/injoyai/base/types"
"github.com/injoyai/conv"
"github.com/injoyai/ios"
"github.com/injoyai/ios/client"
"github.com/injoyai/logs"
"github.com/injoyai/tdx/lib/gbbq"
"github.com/injoyai/tdx/lib/xorms"
"github.com/injoyai/tdx/protocol"
"github.com/robfig/cron/v3"
"xorm.io/xorm"
)
type Codes2Option func(*Codes2)
func WithCodes2Database(filename string) Codes2Option {
return func(c *Codes2) {
c.dbFilename = filename
}
}
func WithCodes2TempDir(dir string) Codes2Option {
return func(c *Codes2) {
c.tempDir = dir
}
}
func WithCodes2Spec(spec string) Codes2Option {
return func(c *Codes2) {
c.spec = spec
}
}
func WithCodes2UpdateKey(key string) Codes2Option {
return func(c *Codes2) {
c.updateKey = key
}
}
func WithCodes2Retry(retry int) Codes2Option {
return func(c *Codes2) {
c.retry = retry
}
}
func WithCodes2Client(c *Client) Codes2Option {
return func(cs *Codes2) {
cs.c = c
}
}
func WithCodes2Dial(dial ios.DialFunc, op ...client.Option) Codes2Option {
return func(c *Codes2) {
c.dial = dial
c.dialOption = op
}
}
func WithCodes2DialOption(op ...client.Option) Codes2Option {
return func(c *Codes2) {
c.dialOption = op
}
}
func NewCodes2(op ...Codes2Option) (*Codes2, error) {
cs := &Codes2{
dbFilename: filepath.Join(DefaultDatabaseDir, "codes2.db"),
tempDir: filepath.Join(DefaultDataDir, "temp"),
spec: "10 0 9 * * *",
updateKey: "codes",
retry: DefaultRetry,
dial: NewRangeDial(Hosts),
dialOption: nil,
m: maps.NewGeneric[string, *CodeModel](),
}
for _, o := range op {
o(cs)
}
os.MkdirAll(cs.tempDir, 0777)
var err error
// 初始化连接
if cs.c == nil {
cs.c, err = DialWith(cs.dial, cs.dialOption...)
if err != nil {
return nil, err
}
}
// 初始化数据库
cs.db, err = xorms.NewSqlite(cs.dbFilename)
if err != nil {
return nil, err
}
if err = cs.db.Sync2(new(CodeModel), new(UpdateModel)); err != nil {
return nil, err
}
// 立即更新
err = cs.Update()
if err != nil {
return nil, err
}
// 定时更新
cr := cron.New(cron.WithSeconds())
_, err = cr.AddFunc(cs.spec, func() {
for i := 0; i == 0 || i < cs.retry; i++ {
if err := cs.Update(); err != nil {
logs.Err(err)
<-time.After(time.Minute * 5)
} else {
break
}
}
})
if err != nil {
return nil, err
}
cr.Start()
return cs, nil
}
var _ ICodes = &Codes2{}
type Codes2 struct {
dbFilename string //数据库文件
tempDir string //临时目录
spec string //定时规则
updateKey string //标识
retry int //重试次数
dial ios.DialFunc //连接
dialOption []client.Option //
/*
内部字段
*/
c *Client //
db *xorms.Engine //
stocks types.List[*CodeModel] //股票缓存
etfs types.List[*CodeModel] //etf缓存
indexes types.List[*CodeModel] //指数缓存
all types.List[*CodeModel] //全部缓存
m *maps.Generic[string, *CodeModel] //缓存
}
func (this *Codes2) Get(code string) *CodeModel {
v, _ := this.m.Get(code)
return v
}
func (this *Codes2) Iter() iter.Seq2[string, *CodeModel] {
return func(yield func(string, *CodeModel) bool) {
for _, code := range this.all {
if !yield(code.FullCode(), code) {
break
}
}
}
}
func (this *Codes2) GetName(code string) string {
v, _ := this.m.Get(code)
if v == nil {
return "未知"
}
return v.Name
}
func (this *Codes2) GetStocks(limit ...int) CodeModels {
size := conv.Default(this.stocks.Len(), limit...)
return CodeModels(this.stocks.Limit(size))
}
func (this *Codes2) GetStockCodes(limit ...int) []string {
return this.GetStocks(limit...).Codes()
}
func (this *Codes2) GetETFs(limit ...int) CodeModels {
size := conv.Default(this.etfs.Len(), limit...)
return CodeModels(this.etfs.Limit(size))
}
func (this *Codes2) GetETFCodes(limit ...int) []string {
return this.GetETFs(limit...).Codes()
}
func (this *Codes2) GetIndexes(limit ...int) CodeModels {
size := conv.Default(this.etfs.Len(), limit...)
return CodeModels(this.indexes.Limit(size))
}
func (this *Codes2) GetIndexCodes(limit ...int) []string {
return this.GetIndexes(limit...).Codes()
}
func (this *Codes2) updated() (bool, error) {
update := new(UpdateModel)
{ //查询或者插入一条数据
has, err := this.db.Where("`Key`=?", this.updateKey).Get(update)
if err != nil {
return true, err
} else if !has {
update.Key = this.updateKey
if _, err = this.db.Insert(update); err != nil {
return true, err
}
return false, nil
}
}
{ //判断是否更新过,更新过则不更新
now := time.Now()
node := time.Date(now.Year(), now.Month(), now.Day(), 9, 0, 0, 0, time.Local)
updateTime := time.Unix(update.Time, 0)
if now.Sub(node) > 0 {
//当前时间在9点之后,且更新时间在9点之前,需要更新
if updateTime.Sub(node) < 0 {
return false, nil
}
} else {
//当前时间在9点之前,且更新时间在上个节点之前
if updateTime.Sub(node.Add(time.Hour*24)) < 0 {
return false, nil
}
}
}
return true, nil
}
func (this *Codes2) Update() error {
codes, err := this.update()
if err != nil {
return err
}
stocks := []*CodeModel(nil)
etfs := []*CodeModel(nil)
indexes := []*CodeModel(nil)
for _, v := range codes {
fullCode := v.FullCode()
this.m.Set(fullCode, v)
switch {
case protocol.IsStock(fullCode):
stocks = append(stocks, v)
case protocol.IsETF(fullCode):
etfs = append(etfs, v)
case protocol.IsIndex(fullCode):
indexes = append(indexes, v)
}
}
this.stocks = stocks
this.etfs = etfs
this.indexes = indexes
this.all = codes
return nil
}
// GetCodes 更新股票并返回结果
func (this *Codes2) update() ([]*CodeModel, error) {
if this.c == nil {
return nil, errors.New("client is nil")
}
//2. 查询数据库所有股票
list := []*CodeModel(nil)
if err := this.db.Find(&list); err != nil {
return nil, err
}
//如果更新过,则不更新
updated, err := this.updated()
if err == nil && updated {
return list, nil
}
mCode := make(map[string]*CodeModel, len(list))
for _, v := range list {
mCode[v.FullCode()] = v
}
//3. 从服务器获取所有股票代码
insert := []*CodeModel(nil)
update := []*CodeModel(nil)
for _, exchange := range []protocol.Exchange{protocol.ExchangeSH, protocol.ExchangeSZ, protocol.ExchangeBJ} {
resp, err := this.c.GetCodeAll(exchange)
if err != nil {
return nil, err
}
for _, v := range resp.List {
code := &CodeModel{
Name: v.Name,
Code: v.Code,
Exchange: exchange.String(),
Multiple: v.Multiple,
Decimal: v.Decimal,
LastPrice: v.LastPrice,
}
if val, ok := mCode[exchange.String()+v.Code]; ok {
if val.Name != v.Name {
update = append(update, code)
}
delete(mCode, exchange.String()+v.Code)
} else {
insert = append(insert, code)
list = append(list, code)
}
}
}
//4. 获取gbbq
ss, err := gbbq.DownloadAndDecode(this.tempDir)
if err != nil {
logs.Err(err)
return nil, err
}
mStock := map[string]gbbq.Stock{}
for _, v := range ss {
mStock[protocol.AddPrefix(v.Code)] = v
}
//5. 赋值流通股和总股本
for _, v := range insert {
if protocol.IsStock(v.FullCode()) {
v.FloatStock, v.TotalStock = ss.GetStock(v.Code)
}
}
for _, v := range update {
if stock, ok := mStock[v.FullCode()]; ok {
v.FloatStock = stock.Float
v.TotalStock = stock.Total
}
}
//6. 插入或者更新数据库
err = this.db.SessionFunc(func(session *xorm.Session) error {
for _, v := range mCode {
if _, err = session.Where("Exchange=? and Code=? ", v.Exchange, v.Code).Delete(v); err != nil {
return err
}
}
for _, v := range insert {
if _, err := session.Insert(v); err != nil {
return err
}
}
for _, v := range update {
if _, err = session.Where("Exchange=? and Code=? ", v.Exchange, v.Code).Cols("Name,LastPrice").Update(v); err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, err
}
//更新时间
_, err = this.db.Where("`Key`=?", this.updateKey).Update(&UpdateModel{Time: time.Now().Unix()})
return list, err
}