mirror of
https://github.com/injoyai/tdx.git
synced 2025-11-26 21:25:35 +08:00
338 lines
7.9 KiB
Go
338 lines
7.9 KiB
Go
package tdx
|
|
|
|
import (
|
|
"errors"
|
|
"github.com/injoyai/conv"
|
|
"github.com/injoyai/ios/client"
|
|
"github.com/injoyai/logs"
|
|
"github.com/injoyai/tdx/protocol"
|
|
"github.com/robfig/cron/v3"
|
|
"math"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
"xorm.io/core"
|
|
"xorm.io/xorm"
|
|
)
|
|
|
|
// DefaultCodes 增加单例,部分数据需要通过Codes里面的信息计算
|
|
var DefaultCodes *Codes
|
|
|
|
func DialCodes(filename string, op ...client.Option) (*Codes, error) {
|
|
c, err := DialDefault(op...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return NewCodes(c, filename)
|
|
}
|
|
|
|
func NewCodes(c *Client, filenames ...string) (*Codes, error) {
|
|
|
|
filename := conv.DefaultString("./codes.db", filenames...)
|
|
|
|
//如果文件夹不存在就创建
|
|
dir, _ := filepath.Split(filename)
|
|
_ = os.MkdirAll(dir, 0777)
|
|
|
|
//连接数据库
|
|
db, err := xorm.NewEngine("sqlite", filename)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
db.SetMapper(core.SameMapper{})
|
|
db.DB().SetMaxOpenConns(1)
|
|
if err := db.Sync2(new(CodeModel)); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := db.Sync2(new(UpdateModel)); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
update := new(UpdateModel)
|
|
{ //查询或者插入一条数据
|
|
has, err := db.Get(update)
|
|
if err != nil {
|
|
return nil, err
|
|
} else if !has {
|
|
if _, err := db.Insert(update); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
}
|
|
|
|
cc := &Codes{
|
|
Client: c,
|
|
db: db,
|
|
}
|
|
|
|
{ //设置定时器,每天早上9点更新数据
|
|
task := cron.New(cron.WithSeconds())
|
|
task.AddFunc("0 0 9 * * *", func() {
|
|
for i := 0; i < 3; i++ {
|
|
if err := cc.Update(); err == nil {
|
|
return
|
|
}
|
|
logs.Err(err)
|
|
<-time.After(time.Minute * 5)
|
|
}
|
|
})
|
|
task.Start()
|
|
}
|
|
|
|
{ //判断是否更新过,更新过则不更新
|
|
now := time.Now()
|
|
node := time.Date(now.Year(), now.Month(), now.Day(), 9, 0, 0, 0, time.Local)
|
|
updateTime := time.Unix(update.Time, 0)
|
|
if now.Sub(node) > 0 {
|
|
//当前时间在9点之后,且更新时间在9点之前,需要更新
|
|
if updateTime.Sub(node) < 0 {
|
|
return cc, cc.Update()
|
|
}
|
|
} else {
|
|
//当前时间在9点之前,且更新时间在上个节点之前
|
|
if updateTime.Sub(node.Add(time.Hour*24)) < 0 {
|
|
return cc, cc.Update()
|
|
}
|
|
}
|
|
}
|
|
|
|
//从缓存中加载
|
|
return cc, cc.Update(true)
|
|
}
|
|
|
|
type Codes struct {
|
|
*Client //客户端
|
|
db *xorm.Engine //数据库实例
|
|
Map map[string]*CodeModel //股票缓存
|
|
list []*CodeModel //列表方式缓存
|
|
exchanges map[string][]string //交易所缓存
|
|
}
|
|
|
|
// GetName 获取股票名称
|
|
func (this *Codes) GetName(code string) string {
|
|
if v, ok := this.Map[code]; ok {
|
|
return v.Name
|
|
}
|
|
return "未知"
|
|
}
|
|
|
|
// GetStocks 获取股票代码,sh6xxx sz0xx sz30xx
|
|
func (this *Codes) GetStocks(limits ...int) []string {
|
|
limit := conv.DefaultInt(-1, limits...)
|
|
ls := []string(nil)
|
|
for _, m := range this.list {
|
|
code := m.FullCode()
|
|
if protocol.IsStock(code) {
|
|
ls = append(ls, code)
|
|
}
|
|
if limit > 0 && len(ls) >= limit {
|
|
break
|
|
}
|
|
}
|
|
return ls
|
|
}
|
|
|
|
func (this *Codes) Get(code string) *CodeModel {
|
|
return this.Map[code]
|
|
}
|
|
|
|
// GetExchange 获取股票交易所,这里的参数不需要带前缀
|
|
func (this *Codes) GetExchange(code string) protocol.Exchange {
|
|
if len(code) == 6 {
|
|
switch {
|
|
case code[:1] == "6":
|
|
return protocol.ExchangeSH
|
|
case code[:1] == "0":
|
|
return protocol.ExchangeSZ
|
|
case code[:2] == "30":
|
|
return protocol.ExchangeSZ
|
|
}
|
|
}
|
|
var exchange string
|
|
exchanges := this.exchanges[code]
|
|
if len(exchanges) >= 1 {
|
|
exchange = exchanges[0]
|
|
}
|
|
if len(code) == 8 {
|
|
exchange = code[0:2]
|
|
}
|
|
switch exchange {
|
|
case protocol.ExchangeSH.String():
|
|
return protocol.ExchangeSH
|
|
case protocol.ExchangeSZ.String():
|
|
return protocol.ExchangeSZ
|
|
default:
|
|
return protocol.ExchangeSH
|
|
}
|
|
}
|
|
|
|
func (this *Codes) AddExchange(code string) string {
|
|
if exchanges := this.exchanges[code]; len(exchanges) == 1 {
|
|
return exchanges[0] + code
|
|
}
|
|
if len(code) == 6 {
|
|
switch {
|
|
case code[:1] == "6":
|
|
return protocol.ExchangeSH.String() + code
|
|
case code[:1] == "0":
|
|
return protocol.ExchangeSZ.String() + code
|
|
case code[:2] == "30":
|
|
return protocol.ExchangeSZ.String() + code
|
|
}
|
|
return this.GetExchange(code).String() + code
|
|
}
|
|
return code
|
|
}
|
|
|
|
// Update 更新数据,从服务器或者数据库
|
|
func (this *Codes) Update(byDB ...bool) error {
|
|
codes, err := this.GetCodes(len(byDB) > 0 && byDB[0])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
codeMap := make(map[string]*CodeModel)
|
|
exchanges := make(map[string][]string)
|
|
for _, code := range codes {
|
|
codeMap[code.Exchange+code.Code] = code
|
|
exchanges[code.Code] = append(exchanges[code.Code], code.Exchange)
|
|
}
|
|
this.Map = codeMap
|
|
this.list = codes
|
|
this.exchanges = exchanges
|
|
//更新时间
|
|
_, err = this.db.Update(&UpdateModel{Time: time.Now().Unix()})
|
|
return err
|
|
}
|
|
|
|
// GetCodes 更新股票并返回结果
|
|
func (this *Codes) GetCodes(byDatabase bool) ([]*CodeModel, error) {
|
|
|
|
if this.Client == nil {
|
|
return nil, errors.New("client is nil")
|
|
}
|
|
|
|
//2. 查询数据库所有股票
|
|
list := []*CodeModel(nil)
|
|
if err := this.db.Find(&list); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
//如果是从缓存读取,则返回结果
|
|
if byDatabase {
|
|
return list, nil
|
|
}
|
|
|
|
mCode := make(map[string]*CodeModel, len(list))
|
|
for _, v := range list {
|
|
mCode[v.Code] = v
|
|
}
|
|
|
|
//3. 从服务器获取所有股票代码
|
|
insert := []*CodeModel(nil)
|
|
update := []*CodeModel(nil)
|
|
for _, exchange := range []protocol.Exchange{protocol.ExchangeSH, protocol.ExchangeSZ} {
|
|
resp, err := this.Client.GetCodeAll(exchange)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, v := range resp.List {
|
|
if _, ok := mCode[v.Code]; ok {
|
|
if mCode[v.Code].Name != v.Name {
|
|
mCode[v.Code].Name = v.Name
|
|
update = append(update, &CodeModel{
|
|
Name: v.Name,
|
|
Code: v.Code,
|
|
Exchange: exchange.String(),
|
|
Multiple: v.Multiple,
|
|
Decimal: v.Decimal,
|
|
LastPrice: v.LastPrice,
|
|
})
|
|
}
|
|
} else {
|
|
code := &CodeModel{
|
|
Name: v.Name,
|
|
Code: v.Code,
|
|
Exchange: exchange.String(),
|
|
Multiple: v.Multiple,
|
|
Decimal: v.Decimal,
|
|
LastPrice: v.LastPrice,
|
|
}
|
|
insert = append(insert, code)
|
|
list = append(list, code)
|
|
}
|
|
}
|
|
}
|
|
|
|
//4. 插入或者更新数据库
|
|
err := NewSessionFunc(this.db, func(session *xorm.Session) error {
|
|
for _, v := range insert {
|
|
if _, err := session.Insert(v); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
for _, v := range update {
|
|
if _, err := session.Where("Exchange=? and Code=? ", v.Exchange, v.Code).Cols("Name,LastPrice").Update(v); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return list, nil
|
|
|
|
}
|
|
|
|
// UpdateModel is the database record holding the time of the last update.
type UpdateModel struct {
	Time int64 // update time; written as Unix seconds by Codes.Update
}

// TableName returns the name of the database table backing UpdateModel.
func (*UpdateModel) TableName() string {
	return "update"
}
|
|
|
|
// CodeModel is the database record describing one code.
type CodeModel struct {
	ID        int64   `json:"id"`                      // primary key
	Name      string  `json:"name"`                    // name; it can change over time, e.g. to STxxx
	Code      string  `json:"code" xorm:"index"`       // code without the exchange prefix
	Exchange  string  `json:"exchange" xorm:"index"`   // exchange
	Multiple  uint16  `json:"multiple"`                // multiple
	Decimal   int8    `json:"decimal"`                 // number of decimal places
	LastPrice float64 `json:"lastPrice"`               // previous closing price
	EditDate  int64   `json:"editDate" xorm:"updated"` // modification time
	InDate    int64   `json:"inDate" xorm:"created"`   // creation time
}

// TableName returns the name of the database table backing CodeModel.
func (*CodeModel) TableName() string {
	return "codes"
}

// FullCode returns the exchange followed directly by the code.
func (this *CodeModel) FullCode() string {
	full := this.Exchange + this.Code
	return full
}
|
|
|
|
func (this *CodeModel) Price(p protocol.Price) protocol.Price {
|
|
return protocol.Price(float64(p) * math.Pow10(int(2-this.Decimal)))
|
|
return p * protocol.Price(math.Pow10(int(2-this.Decimal)))
|
|
}
|
|
|
|
func NewSessionFunc(db *xorm.Engine, fn func(session *xorm.Session) error) error {
|
|
session := db.NewSession()
|
|
defer session.Close()
|
|
if err := session.Begin(); err != nil {
|
|
session.Rollback()
|
|
return err
|
|
}
|
|
if err := fn(session); err != nil {
|
|
session.Rollback()
|
|
return err
|
|
}
|
|
if err := session.Commit(); err != nil {
|
|
session.Rollback()
|
|
return err
|
|
}
|
|
return nil
|
|
}
|