// Mirror of https://github.com/injoyai/tdx.git (synced 2025-11-26 21:25:35 +08:00).
// Go source, 403 lines, 9.5 KiB.

package tdx
|
||
|
||
import (
|
||
"errors"
|
||
"github.com/injoyai/conv"
|
||
"github.com/injoyai/ios/client"
|
||
"github.com/injoyai/logs"
|
||
"github.com/injoyai/tdx/protocol"
|
||
"github.com/robfig/cron/v3"
|
||
"iter"
|
||
"math"
|
||
"os"
|
||
"path/filepath"
|
||
"time"
|
||
"xorm.io/core"
|
||
"xorm.io/xorm"
|
||
)
|
||
|
||
// ICodes is the read-only lookup interface over the cached code table.
// In this file it is implemented by *Codes.
type ICodes interface {
	// Iter returns an iterator yielding a string key together with its *CodeModel.
	Iter() iter.Seq2[string, *CodeModel]
	// Get returns the model stored for code.
	Get(code string) *CodeModel
	// GetName returns the display name stored for code.
	GetName(code string) string
	// GetStocks returns the stock models; an optional limit caps the result size.
	GetStocks(limit ...int) CodeModels
	// GetStockCodes returns the codes of the stock models; an optional limit caps the result size.
	GetStockCodes(limit ...int) []string
	// GetETFs returns the ETF models; an optional limit caps the result size.
	GetETFs(limit ...int) CodeModels
	// GetETFCodes returns the codes of the ETF models; an optional limit caps the result size.
	GetETFCodes(limit ...int) []string
}
|
||
|
||
// DefaultCodes is a package-level singleton; some data has to be computed
// from the information held in Codes.
var DefaultCodes *Codes
|
||
|
||
func DialCodes(filename string, op ...client.Option) (*Codes, error) {
|
||
c, err := DialDefault(op...)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return NewCodesSqlite(c, filename)
|
||
}
|
||
|
||
func NewCodesMysql(c *Client, dsn string) (*Codes, error) {
|
||
|
||
//连接数据库
|
||
db, err := xorm.NewEngine("mysql", dsn)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
db.SetMapper(core.SameMapper{})
|
||
|
||
return NewCodes(c, db)
|
||
}
|
||
|
||
func NewCodesSqlite(c *Client, filenames ...string) (*Codes, error) {
|
||
|
||
//如果没有指定文件名,则使用默认
|
||
defaultFilename := filepath.Join(DefaultDatabaseDir, "codes.db")
|
||
filename := conv.Default(defaultFilename, filenames...)
|
||
filename = conv.Select(filename == "", defaultFilename, filename)
|
||
|
||
//如果文件夹不存在就创建
|
||
dir, _ := filepath.Split(filename)
|
||
_ = os.MkdirAll(dir, 0777)
|
||
|
||
//连接数据库
|
||
db, err := xorm.NewEngine("sqlite", filename)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
db.SetMapper(core.SameMapper{})
|
||
db.DB().SetMaxOpenConns(1)
|
||
|
||
return NewCodes(c, db)
|
||
}
|
||
|
||
func NewCodes(c *Client, db *xorm.Engine) (*Codes, error) {
|
||
|
||
if err := db.Sync2(new(CodeModel)); err != nil {
|
||
return nil, err
|
||
}
|
||
if err := db.Sync2(new(UpdateModel)); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
update := new(UpdateModel)
|
||
{ //查询或者插入一条数据
|
||
has, err := db.Where("`Key`=?", "codes").Get(update)
|
||
if err != nil {
|
||
return nil, err
|
||
} else if !has {
|
||
update.Key = "codes"
|
||
if _, err := db.Insert(update); err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
}
|
||
|
||
cc := &Codes{
|
||
Client: c,
|
||
db: db,
|
||
}
|
||
|
||
{ //设置定时器,每天早上9点更新数据
|
||
task := cron.New(cron.WithSeconds())
|
||
task.AddFunc("10 0 9 * * *", func() {
|
||
for i := 0; i < 3; i++ {
|
||
err := cc.Update()
|
||
if err == nil {
|
||
return
|
||
}
|
||
logs.Err(err)
|
||
<-time.After(time.Minute * 5)
|
||
}
|
||
})
|
||
task.Start()
|
||
}
|
||
|
||
{ //判断是否更新过,更新过则不更新
|
||
now := time.Now()
|
||
node := time.Date(now.Year(), now.Month(), now.Day(), 9, 0, 0, 0, time.Local)
|
||
updateTime := time.Unix(update.Time, 0)
|
||
if now.Sub(node) > 0 {
|
||
//当前时间在9点之后,且更新时间在9点之前,需要更新
|
||
if updateTime.Sub(node) < 0 {
|
||
return cc, cc.Update()
|
||
}
|
||
} else {
|
||
//当前时间在9点之前,且更新时间在上个节点之前
|
||
if updateTime.Sub(node.Add(time.Hour*24)) < 0 {
|
||
return cc, cc.Update()
|
||
}
|
||
}
|
||
}
|
||
|
||
//从缓存中加载
|
||
return cc, cc.Update(true)
|
||
}
|
||
|
||
// Codes keeps the code table in memory on top of a database.
// Its cache fields are rebuilt by Update.
type Codes struct {
	*Client                         // client
	db        *xorm.Engine          // database instance
	Map       map[string]*CodeModel // security cache; Update keys it by Exchange+Code
	list      []*CodeModel          // the same records cached as a list
	exchanges map[string][]string   // exchange cache; Update maps each Code to its exchanges
}
|
||
|
||
func (this *Codes) Get(code string) *CodeModel {
|
||
return this.Map[code]
|
||
}
|
||
|
||
func (this *Codes) Iter() iter.Seq2[string, *CodeModel] {
|
||
return func(yield func(string, *CodeModel) bool) {
|
||
for _, code := range this.list {
|
||
if !yield(code.FullCode(), code) {
|
||
break
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// GetName 获取股票名称
|
||
func (this *Codes) GetName(code string) string {
|
||
if v, ok := this.Map[code]; ok {
|
||
return v.Name
|
||
}
|
||
return "未知"
|
||
}
|
||
|
||
// GetStocks 获取股票代码,sh6xxx sz0xx sz30xx
|
||
func (this *Codes) GetStocks(limits ...int) CodeModels {
|
||
limit := conv.Default(-1, limits...)
|
||
ls := []*CodeModel(nil)
|
||
for _, m := range this.list {
|
||
code := m.FullCode()
|
||
if protocol.IsStock(code) {
|
||
ls = append(ls, m)
|
||
}
|
||
if limit > 0 && len(ls) >= limit {
|
||
break
|
||
}
|
||
}
|
||
return ls
|
||
}
|
||
|
||
func (this *Codes) GetStockCodes(limits ...int) []string {
|
||
return this.GetStocks(limits...).Codes()
|
||
}
|
||
|
||
// GetETFs 获取基金代码,sz159xxx,sh510xxx,sh511xxx
|
||
func (this *Codes) GetETFs(limits ...int) CodeModels {
|
||
limit := conv.Default(-1, limits...)
|
||
ls := []*CodeModel(nil)
|
||
for _, m := range this.list {
|
||
code := m.FullCode()
|
||
if protocol.IsETF(code) {
|
||
ls = append(ls, m)
|
||
}
|
||
if limit > 0 && len(ls) >= limit {
|
||
break
|
||
}
|
||
}
|
||
return ls
|
||
}
|
||
|
||
// GetETFCodes 获取基金代码,sz159xxx,sh510xxx,sh511xxx
|
||
func (this *Codes) GetETFCodes(limits ...int) []string {
|
||
return this.GetETFs(limits...).Codes()
|
||
}
|
||
|
||
func (this *Codes) AddExchange(code string) string {
|
||
return protocol.AddPrefix(code)
|
||
}
|
||
|
||
// Update 更新数据,从服务器或者数据库
|
||
func (this *Codes) Update(byDB ...bool) error {
|
||
codes, err := this.GetCodes(len(byDB) > 0 && byDB[0])
|
||
if err != nil {
|
||
return err
|
||
}
|
||
codeMap := make(map[string]*CodeModel)
|
||
exchanges := make(map[string][]string)
|
||
for _, code := range codes {
|
||
codeMap[code.Exchange+code.Code] = code
|
||
exchanges[code.Code] = append(exchanges[code.Code], code.Exchange)
|
||
}
|
||
this.Map = codeMap
|
||
this.list = codes
|
||
this.exchanges = exchanges
|
||
//更新时间
|
||
_, err = this.db.Where("`Key`=?", "codes").Update(&UpdateModel{Time: time.Now().Unix()})
|
||
return err
|
||
}
|
||
|
||
// GetCodes 更新股票并返回结果
|
||
func (this *Codes) GetCodes(byDatabase bool) ([]*CodeModel, error) {
|
||
|
||
if this.Client == nil {
|
||
return nil, errors.New("client is nil")
|
||
}
|
||
|
||
//2. 查询数据库所有股票
|
||
list := []*CodeModel(nil)
|
||
if err := this.db.Find(&list); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
//如果是从缓存读取,则返回结果
|
||
if byDatabase {
|
||
return list, nil
|
||
}
|
||
|
||
mCode := make(map[string]*CodeModel, len(list))
|
||
for _, v := range list {
|
||
mCode[v.Code] = v
|
||
}
|
||
|
||
//3. 从服务器获取所有股票代码
|
||
insert := []*CodeModel(nil)
|
||
update := []*CodeModel(nil)
|
||
for _, exchange := range []protocol.Exchange{protocol.ExchangeSH, protocol.ExchangeSZ, protocol.ExchangeBJ} {
|
||
resp, err := this.Client.GetCodeAll(exchange)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, v := range resp.List {
|
||
if _, ok := mCode[v.Code]; ok {
|
||
if mCode[v.Code].Name != v.Name {
|
||
mCode[v.Code].Name = v.Name
|
||
update = append(update, &CodeModel{
|
||
Name: v.Name,
|
||
Code: v.Code,
|
||
Exchange: exchange.String(),
|
||
Multiple: v.Multiple,
|
||
Decimal: v.Decimal,
|
||
LastPrice: v.LastPrice,
|
||
})
|
||
}
|
||
} else {
|
||
code := &CodeModel{
|
||
Name: v.Name,
|
||
Code: v.Code,
|
||
Exchange: exchange.String(),
|
||
Multiple: v.Multiple,
|
||
Decimal: v.Decimal,
|
||
LastPrice: v.LastPrice,
|
||
}
|
||
insert = append(insert, code)
|
||
list = append(list, code)
|
||
}
|
||
}
|
||
}
|
||
|
||
switch this.db.Dialect().URI().DBType {
|
||
case "mysql":
|
||
// 1️⃣ 清空
|
||
if _, err := this.db.Exec("TRUNCATE TABLE codes"); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
data := append(insert, update...)
|
||
// 2️⃣ 直接批量插入
|
||
batchSize := 3000 // 8000(2m16s) 5000(43s) 3000(11s) 1000(59s)
|
||
for i := 0; i < len(data); i += batchSize {
|
||
end := i + batchSize
|
||
if end > len(data) {
|
||
end = len(data)
|
||
}
|
||
|
||
slice := conv.Array(data[i:end])
|
||
if _, err := this.db.Insert(slice); err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
case "sqlite3":
|
||
//4. 插入或者更新数据库
|
||
err := NewSessionFunc(this.db, func(session *xorm.Session) error {
|
||
for _, v := range insert {
|
||
if _, err := session.Insert(v); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
for _, v := range update {
|
||
if _, err := session.Where("Exchange=? and Code=? ", v.Exchange, v.Code).Cols("Name,LastPrice").Update(v); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
|
||
return list, nil
|
||
}
|
||
|
||
// UpdateModel is a bookkeeping row recording when a dataset was last
// updated. This file uses the row whose Key is "codes".
type UpdateModel struct {
	Key  string
	Time int64 // update time, as a unix timestamp in seconds
}
|
||
|
||
func (*UpdateModel) TableName() string {
|
||
return "update"
|
||
}
|
||
|
||
// CodeModel is one row of the "codes" table: a single security on one exchange.
type CodeModel struct {
	ID         int64   `json:"id"`                      // primary key
	Name       string  `json:"name"`                    // name; it can change over time, e.g. STxxx
	Code       string  `json:"code" xorm:"index"`       // code
	Exchange   string  `json:"exchange" xorm:"index"`   // exchange
	Multiple   uint16  `json:"multiple"`                // multiple
	Decimal    int8    `json:"decimal"`                 // number of decimal places
	LastPrice  float64 `json:"lastPrice"`               // previous close price
	FloatStock float64 `json:"floatStock"`              // floating shares
	TotalStock float64 `json:"totalStock"`              // total share capital
	EditDate   int64   `json:"editDate" xorm:"updated"` // modification time
	InDate     int64   `json:"inDate" xorm:"created"`   // creation time
}
|
||
|
||
func (*CodeModel) TableName() string {
|
||
return "codes"
|
||
}
|
||
|
||
// FullCode 获取完整代码 sz000001
|
||
func (this *CodeModel) FullCode() string {
|
||
return this.Exchange + this.Code
|
||
}
|
||
|
||
// Turnover 换手率
|
||
func (this *CodeModel) Turnover(volume float64) float64 {
|
||
return volume / (this.FloatStock * 100)
|
||
}
|
||
|
||
func (this *CodeModel) Price(p protocol.Price) protocol.Price {
|
||
return protocol.Price(float64(p) * math.Pow10(int(2-this.Decimal)))
|
||
}
|
||
|
||
func NewSessionFunc(db *xorm.Engine, fn func(session *xorm.Session) error) error {
|
||
session := db.NewSession()
|
||
defer session.Close()
|
||
if err := session.Begin(); err != nil {
|
||
session.Rollback()
|
||
return err
|
||
}
|
||
if err := fn(session); err != nil {
|
||
session.Rollback()
|
||
return err
|
||
}
|
||
if err := session.Commit(); err != nil {
|
||
session.Rollback()
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// CodeModels is a slice of *CodeModel with helper methods.
type CodeModels []*CodeModel
|
||
|
||
func (this CodeModels) Codes() []string {
|
||
codes := make([]string, len(this))
|
||
for i, v := range this {
|
||
codes[i] = v.FullCode()
|
||
}
|
||
return codes
|
||
}
|