Files
injoyai-tdx/codes.go
2025-05-13 11:17:01 +08:00

341 lines
8.1 KiB
Go

package tdx
import (
"errors"
"github.com/injoyai/conv"
"github.com/injoyai/ios/client"
"github.com/injoyai/logs"
"github.com/injoyai/tdx/protocol"
"github.com/robfig/cron/v3"
"math"
"os"
"path/filepath"
"time"
"xorm.io/core"
"xorm.io/xorm"
)
// DefaultCodes is a package-level singleton *Codes, added because some data
// has to be computed from the information held in Codes.
// NOTE(review): it is not assigned in the code visible here; confirm where it
// is initialized before relying on it being non-nil.
var DefaultCodes *Codes
// DialCodes dials a client with DialDefault (passing op through) and then
// opens the code database at filename via NewCodes.
//
// NOTE(review): if NewCodes fails, the freshly dialed client is not closed
// here; confirm whether callers are expected to release it.
func DialCodes(filename string, op ...client.Option) (*Codes, error) {
	cli, dialErr := DialDefault(op...)
	if dialErr == nil {
		return NewCodes(cli, filename)
	}
	return nil, dialErr
}
// NewCodes opens (creating it if needed) the SQLite database that caches the
// code list and returns a Codes bound to client c.
//
// filenames optionally overrides the database path (resolved through
// conv.Default); when no name is given, or the resolved name is empty,
// DefaultDatabaseDir/codes.db is used.
//
// It also starts a cron job that refreshes the list from the server every day
// at 09:00:00, with up to 3 attempts spaced 5 minutes apart. If the stored
// update time is older than the most recent 09:00 refresh point, the list is
// fetched from the server immediately. Otherwise it is loaded from the
// database.
//
// When that initial load fails, the non-nil *Codes is still returned together
// with the error.
func NewCodes(c *Client, filenames ...string) (*Codes, error) {
	// Fall back to the default database file when no usable name is given.
	defaultFilename := filepath.Join(DefaultDatabaseDir, "codes.db")
	filename := conv.Default[string](defaultFilename, filenames...)
	filename = conv.Select[string](filename == "", defaultFilename, filename)

	db, update, err := openCodesDB(filename)
	if err != nil {
		return nil, err
	}

	cc := &Codes{
		Client: c,
		db:     db,
	}

	// NOTE(review): the scheduler is never stopped, so it outlives cc.
	if err := startCodesDailyUpdate(cc); err != nil {
		_ = db.Close()
		return nil, err
	}

	if codesNeedUpdate(time.Now(), time.Unix(update.Time, 0)) {
		return cc, cc.Update()
	}
	// The cached data is still current: load it from the database.
	return cc, cc.Update(true)
}

// openCodesDB opens the SQLite database at filename, creating its directory
// first. It syncs the CodeModel and UpdateModel tables and returns the engine
// together with the first UpdateModel row. If the table is empty, a row with a
// zero Time is inserted and returned. The engine is closed on any error.
func openCodesDB(filename string) (*xorm.Engine, *UpdateModel, error) {
	// Best effort: if the directory cannot be created, the failure surfaces
	// below when the database is opened or synced.
	dir, _ := filepath.Split(filename)
	_ = os.MkdirAll(dir, 0777)

	db, err := xorm.NewEngine("sqlite", filename)
	if err != nil {
		return nil, nil, err
	}
	fail := func(err error) (*xorm.Engine, *UpdateModel, error) {
		_ = db.Close()
		return nil, nil, err
	}

	db.SetMapper(core.SameMapper{})
	// Presumably limited to a single connection to serialize SQLite access.
	db.DB().SetMaxOpenConns(1)
	if err := db.Sync2(new(CodeModel)); err != nil {
		return fail(err)
	}
	if err := db.Sync2(new(UpdateModel)); err != nil {
		return fail(err)
	}

	// Read the update-time row, inserting an empty one on first use.
	update := new(UpdateModel)
	has, err := db.Get(update)
	if err != nil {
		return fail(err)
	}
	if !has {
		if _, err := db.Insert(update); err != nil {
			return fail(err)
		}
	}
	return db, update, nil
}

// startCodesDailyUpdate schedules cc.Update for every day at 09:00:00 (cron
// spec with a seconds field). Each run makes up to three attempts. Every
// failure is logged, and the job waits five minutes before each retry.
func startCodesDailyUpdate(cc *Codes) error {
	const attempts = 3
	task := cron.New(cron.WithSeconds())
	_, err := task.AddFunc("0 0 9 * * *", func() {
		for i := 0; i < attempts; i++ {
			// Declared per attempt so the logged error is this attempt's error.
			updateErr := cc.Update()
			if updateErr == nil {
				return
			}
			logs.Err(updateErr)
			if i < attempts-1 {
				<-time.After(5 * time.Minute)
			}
		}
	})
	if err != nil {
		return err
	}
	task.Start()
	return nil
}

// codesNeedUpdate reports whether data last refreshed at updateTime is stale
// at now. Data is stale when it predates the most recent daily 09:00 refresh
// point: today's 09:00 if now is past it, otherwise yesterday's 09:00.
func codesNeedUpdate(now, updateTime time.Time) bool {
	node := time.Date(now.Year(), now.Month(), now.Day(), 9, 0, 0, 0, time.Local)
	if !now.After(node) {
		// Before today's 09:00, the latest refresh point is yesterday's.
		// AddDate keeps 09:00 wall-clock time across DST changes.
		node = node.AddDate(0, 0, -1)
	}
	return updateTime.Before(node)
}
// Codes is an in-memory cache of the code list. It is backed by a local
// database (db) and refreshed from the server through the embedded Client;
// see Update.
// NOTE(review): the cache fields are replaced by Update without any locking.
type Codes struct {
*Client // client used to fetch code lists from the server
db *xorm.Engine // database instance
Map map[string]*CodeModel // code cache keyed by Exchange+Code
list []*CodeModel // the same cached codes as a slice
exchanges map[string][]string // exchange cache: bare code -> exchanges listing it
}
// GetName returns the cached name for code, where code is the full
// Exchange+Code key used by Map. It returns "未知" ("unknown") when the code
// is not cached.
func (this *Codes) GetName(code string) string {
	model, found := this.Map[code]
	if !found {
		return "未知"
	}
	return model.Name
}
// GetStocks returns, in cache order, the full codes (exchange prefix + code,
// e.g. sh6xxx, sz0xx, sz30xx) of cached entries that protocol.IsStock
// accepts.
//
// An optional positive limit caps the number of results. With no limit, or a
// non-positive one, every match is returned.
func (this *Codes) GetStocks(limits ...int) []string {
	capN := conv.Default[int](-1, limits...)
	var stocks []string
	for _, item := range this.list {
		if capN > 0 && len(stocks) >= capN {
			break
		}
		if full := item.FullCode(); protocol.IsStock(full) {
			stocks = append(stocks, full)
		}
	}
	return stocks
}
// Get returns the cached model for code (the full Exchange+Code key), or nil
// when it is not cached.
func (this *Codes) Get(code string) *CodeModel {
	if model, ok := this.Map[code]; ok {
		return model
	}
	return nil
}
// GetExchange returns the exchange of code. code is normally given without an
// exchange prefix.
//
// A six-digit code starting with "6" is SH, and one starting with "0" or "30"
// is SZ. Otherwise, an 8-character code takes its exchange from its first two
// characters, and any other code uses the first exchange recorded for it in
// the cache. Anything that does not resolve to SZ defaults to SH.
func (this *Codes) GetExchange(code string) protocol.Exchange {
	if len(code) == 6 {
		if code[0] == '6' {
			return protocol.ExchangeSH
		}
		if code[0] == '0' || (code[0] == '3' && code[1] == '0') {
			return protocol.ExchangeSZ
		}
	}

	prefix := ""
	if len(code) == 8 {
		prefix = code[:2]
	} else if known := this.exchanges[code]; len(known) > 0 {
		prefix = known[0]
	}

	if prefix == protocol.ExchangeSZ.String() {
		return protocol.ExchangeSZ
	}
	return protocol.ExchangeSH
}
// AddExchange returns code prefixed with its exchange.
//
// A code listed on exactly one exchange in the cache uses that exchange.
// Otherwise, a six-digit code is classified by its leading digits ("6" -> SH,
// "0"/"30" -> SZ) or, failing that, by GetExchange. Codes of any other length
// are returned unchanged.
func (this *Codes) AddExchange(code string) string {
	if known := this.exchanges[code]; len(known) == 1 {
		return known[0] + code
	}
	if len(code) != 6 {
		return code
	}

	var prefix string
	switch {
	case code[0] == '6':
		prefix = protocol.ExchangeSH.String()
	case code[0] == '0', code[0] == '3' && code[1] == '0':
		prefix = protocol.ExchangeSZ.String()
	default:
		prefix = this.GetExchange(code).String()
	}
	return prefix + code
}
// Update reloads the in-memory cache (Map, list and exchanges).
//
// With no argument, or byDB[0] == false, the codes are fetched from the
// server (and persisted) via GetCodes, and the stored update time is set to
// now. With byDB[0] == true they are read from the local database only. In
// that case the stored update time is left untouched, so it keeps recording
// the last server refresh.
//
// NOTE(review): the three cache fields are swapped without locking while
// other methods read them. The daily cron job started by NewCodes calls
// Update, so confirm whether concurrent lookups need a mutex.
func (this *Codes) Update(byDB ...bool) error {
	fromDB := len(byDB) > 0 && byDB[0]
	codes, err := this.GetCodes(fromDB)
	if err != nil {
		return err
	}

	codeMap := make(map[string]*CodeModel, len(codes))
	exchanges := make(map[string][]string, len(codes))
	for _, code := range codes {
		codeMap[code.Exchange+code.Code] = code
		exchanges[code.Code] = append(exchanges[code.Code], code.Exchange)
	}
	this.Map = codeMap
	this.list = codes
	this.exchanges = exchanges

	if fromDB {
		// Loading the cache is not a server refresh; keep the stored time.
		return nil
	}
	// Record when the data was last refreshed from the server.
	_, err = this.db.Update(&UpdateModel{Time: time.Now().Unix()})
	return err
}
// GetCodes returns every stored code, optionally refreshing from the server
// first.
//
// When byDatabase is true, the rows currently in the database are returned
// as-is. Otherwise the full code lists of the SH and SZ exchanges are fetched
// through the client. Codes not yet stored are inserted, and codes whose name
// changed have Name and LastPrice updated, all in one transaction. The
// returned slice holds the stored rows (reflecting those updates) followed by
// the newly inserted ones.
//
// An error is returned when the client is nil, even if byDatabase is true.
func (this *Codes) GetCodes(byDatabase bool) ([]*CodeModel, error) {
	if this.Client == nil {
		return nil, errors.New("client is nil")
	}

	// 1. Load every code already stored in the database.
	list := []*CodeModel(nil)
	if err := this.db.Find(&list); err != nil {
		return nil, err
	}
	// Cache-only mode: return what is stored.
	if byDatabase {
		return list, nil
	}

	// Index by Exchange+Code. The same bare code can be listed on both
	// exchanges, so the bare code alone is not a unique key.
	known := make(map[string]*CodeModel, len(list))
	for _, v := range list {
		known[v.Exchange+v.Code] = v
	}

	// 2. Fetch every code from the server and diff it against the stored rows.
	var insert, update []*CodeModel
	for _, exchange := range []protocol.Exchange{protocol.ExchangeSH, protocol.ExchangeSZ} {
		resp, err := this.Client.GetCodeAll(exchange)
		if err != nil {
			return nil, err
		}
		for _, v := range resp.List {
			fresh := &CodeModel{
				Name:      v.Name,
				Code:      v.Code,
				Exchange:  exchange.String(),
				Multiple:  v.Multiple,
				Decimal:   v.Decimal,
				LastPrice: v.LastPrice,
			}
			key := fresh.Exchange + fresh.Code
			stored, ok := known[key]
			if !ok {
				insert = append(insert, fresh)
				list = append(list, fresh)
				// Guard against the server listing the same code twice.
				known[key] = fresh
				continue
			}
			if stored.Name != v.Name {
				// Keep the returned row consistent with the columns written below.
				stored.Name = v.Name
				stored.LastPrice = v.LastPrice
				update = append(update, fresh)
			}
		}
	}

	// 3. Persist the inserts and updates in a single transaction.
	err := NewSessionFunc(this.db, func(session *xorm.Session) error {
		for _, v := range insert {
			if _, err := session.Insert(v); err != nil {
				return err
			}
		}
		for _, v := range update {
			if _, err := session.Where("Exchange=? and Code=? ", v.Exchange, v.Code).Cols("Name,LastPrice").Update(v); err != nil {
				return err
			}
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	return list, nil
}
// UpdateModel is the row in the "update" table that records when the code
// list was last refreshed. NewCodes reads the first row, inserting an empty
// one if none exists.
type UpdateModel struct {
Time int64 // last update time, in Unix seconds
}
// TableName maps UpdateModel to the "update" table for xorm.
func (*UpdateModel) TableName() string {
return "update"
}
// CodeModel is one code row of the "codes" table.
type CodeModel struct {
ID int64 `json:"id"` // primary key
Name string `json:"name"` // name; it can change over time, e.g. to STxxx
Code string `json:"code" xorm:"index"` // code without the exchange prefix
Exchange string `json:"exchange" xorm:"index"` // exchange (prefix joined with Code by FullCode)
Multiple uint16 `json:"multiple"` // multiplier
Decimal int8 `json:"decimal"` // number of decimal places
LastPrice float64 `json:"lastPrice"` // previous close price
EditDate int64 `json:"editDate" xorm:"updated"` // modification time (xorm `updated` tag)
InDate int64 `json:"inDate" xorm:"created"` // creation time (xorm `created` tag)
}
// TableName maps CodeModel to the "codes" table for xorm.
func (*CodeModel) TableName() string {
return "codes"
}
// FullCode returns the code prefixed with its exchange, i.e. Exchange followed
// by Code.
func (this *CodeModel) FullCode() string {
	full := this.Exchange
	full += this.Code
	return full
}
// Price rescales p according to this code's Decimal (number of decimal
// places). p is multiplied by 10^(2-Decimal), so a Decimal of 2 leaves p
// unchanged.
func (this *CodeModel) Price(p protocol.Price) protocol.Price {
	// Widen Decimal before subtracting so the exponent cannot overflow int8.
	return protocol.Price(float64(p) * math.Pow10(2-int(this.Decimal)))
}
// NewSessionFunc runs fn inside a transaction on a new session of db. The
// transaction is committed when fn returns nil. If beginning the transaction,
// fn, or the commit fails, a rollback is attempted (its error is ignored) and
// the original error is returned. The session is always closed.
func NewSessionFunc(db *xorm.Engine, fn func(session *xorm.Session) error) error {
	sess := db.NewSession()
	defer sess.Close()

	err := sess.Begin()
	if err == nil {
		if err = fn(sess); err == nil {
			err = sess.Commit()
		}
	}
	if err != nil {
		_ = sess.Rollback()
	}
	return err
}