Files
bensema-gotdx/client.go
2024-10-03 02:11:09 +08:00

302 lines
6.9 KiB
Go

package gotdx
import (
"bytes"
"compress/zlib"
"encoding/binary"
"errors"
"github.com/bensema/gotdx/proto"
"io"
"log"
"net"
"sync"
)
// New returns a Client configured by opts. It performs no network I/O;
// call Connect to dial the server.
func New(opts ...Option) *Client {
	return &Client{
		opt:      applyOptions(opts...),
		sending:  make(chan bool, 1),
		complete: make(chan bool, 1),
	}
}
// Client talks to a market-data (quote) server over a single TCP
// connection. Every exported request method locks mu for the full
// request/response round trip, so requests on one Client are serialised.
type Client struct {
	conn     net.Conn   // set by connect after a successful dial; nil before that
	opt      *Options   // configuration produced by applyOptions in New
	complete chan bool  // buffered (cap 1), created in New; NOTE(review): not used in this file — confirm purpose
	sending  chan bool  // buffered (cap 1), created in New; NOTE(review): not used in this file — confirm purpose
	mu       sync.Mutex // guards conn and the request/response exchange
}
// connect dials client.opt.TCPAddress over TCP and stores the new connection
// on the client.
//
// If the client already holds a connection, it is closed once the new dial
// succeeds, so calling Connect repeatedly does not leak sockets. If the dial
// fails, the existing connection (if any) is left untouched and the dial error
// is returned.
func (client *Client) connect() error {
	conn, err := net.Dial("tcp", client.opt.TCPAddress)
	if err != nil {
		return err
	}
	if old := client.conn; old != nil {
		// Best effort: the old socket is being replaced regardless, and it may
		// already have been closed by Disconnect.
		_ = old.Close()
	}
	client.conn = conn
	return nil
}
// do performs one request/response round trip for msg on the current
// connection.
//
// The request bytes come from msg.Serialize. On a short write, only the
// unsent tail is retried; resending from the start would duplicate bytes the
// peer already received. At most client.opt.MaxRetryTimes retries are made.
//
// The reply is read as:
//   - a header of proto.MessageHeaderBytes bytes, decoded little-endian into
//     proto.RespHeader;
//   - then header.ZipSize body bytes.
//
// Bodies larger than proto.MessageMaxBytes are rejected with ErrBadData. When
// ZipSize != UnZipSize the body is zlib-inflated before being passed to
// msg.UnSerialize; otherwise it is passed through unchanged.
//
// do takes no lock itself; every caller in this file holds client.mu.
func (client *Client) do(msg proto.Msg) error {
	if client.conn == nil {
		return errors.New("gotdx: not connected")
	}
	sendData, err := msg.Serialize()
	if err != nil {
		return err
	}

	// Send the request, resuming after short writes.
	pending := sendData
	for retryTimes := 0; ; retryTimes++ {
		n, err := client.conn.Write(pending)
		if n >= len(pending) {
			if err != nil {
				return err
			}
			break
		}
		// Keep only the bytes that have not been delivered yet.
		pending = pending[n:]
		if retryTimes >= client.opt.MaxRetryTimes {
			if err == nil {
				// io.Writer must report an error on a short write; guard anyway.
				err = io.ErrShortWrite
			}
			return err
		}
		log.Printf("第%d次重试\n", retryTimes+1)
	}

	// Read and decode the fixed-size response header.
	headerBytes := make([]byte, proto.MessageHeaderBytes)
	if _, err := io.ReadFull(client.conn, headerBytes); err != nil {
		return err
	}
	var header proto.RespHeader
	if err := binary.Read(bytes.NewReader(headerBytes), binary.LittleEndian, &header); err != nil {
		return err
	}

	// ZipSize comes straight off the wire; bound it before allocating.
	if header.ZipSize > proto.MessageMaxBytes {
		log.Printf("msgData has bytes(%d) beyond max %d\n", header.ZipSize, proto.MessageMaxBytes)
		return ErrBadData
	}
	msgData := make([]byte, header.ZipSize)
	if _, err := io.ReadFull(client.conn, msgData); err != nil {
		return err
	}

	if header.ZipSize == header.UnZipSize {
		return msg.UnSerialize(&header, msgData)
	}

	// Sizes differ: the body is zlib-compressed. Surface malformed or
	// truncated streams as errors instead of decoding partial data (or
	// panicking on a nil reader).
	zr, err := zlib.NewReader(bytes.NewReader(msgData))
	if err != nil {
		return err
	}
	defer zr.Close()
	var out bytes.Buffer
	if _, err := io.Copy(&out, zr); err != nil {
		return err
	}
	return msg.UnSerialize(&header, out.Bytes())
}
// Connect dials the broker's quote server at the configured TCP address and
// performs the initial Hello1 handshake, returning the server's reply.
//
// If the handshake fails, the newly opened connection is closed so it is not
// leaked, and the handshake error is returned.
func (client *Client) Connect() (*proto.Hello1Reply, error) {
	client.mu.Lock()
	defer client.mu.Unlock()
	if err := client.connect(); err != nil {
		return nil, err
	}
	obj := proto.NewHello1()
	if err := client.do(obj); err != nil {
		// Do not leave a half-initialised connection open.
		_ = client.conn.Close()
		return nil, err
	}
	return obj.Reply(), nil
}
// Disconnect closes the connection to the server. If no connection has been
// established yet, it does nothing and returns nil instead of panicking on
// the nil connection.
func (client *Client) Disconnect() error {
	client.mu.Lock()
	defer client.mu.Unlock()
	if client.conn == nil {
		return nil
	}
	return client.conn.Close()
}
// GetSecurityCount returns the number of securities in the given market.
func (client *Client) GetSecurityCount(market uint16) (*proto.GetSecurityCountReply, error) {
	client.mu.Lock()
	defer client.mu.Unlock()

	req := proto.NewGetSecurityCount()
	req.SetParams(&proto.GetSecurityCountRequest{Market: market})
	if err := client.do(req); err != nil {
		return nil, err
	}
	return req.Reply(), nil
}
// GetSecurityQuotes fetches five-level order-book quotes for each
// (markets[i], codes[i]) pair. markets and codes must be the same length;
// otherwise an error is returned and nothing is sent.
func (client *Client) GetSecurityQuotes(markets []uint8, codes []string) (*proto.GetSecurityQuotesReply, error) {
	client.mu.Lock()
	defer client.mu.Unlock()

	if len(codes) != len(markets) {
		return nil, errors.New("market code count error")
	}

	var stocks []proto.Stock
	for i := range markets {
		stocks = append(stocks, proto.Stock{Market: markets[i], Code: codes[i]})
	}

	req := proto.NewGetSecurityQuotes()
	req.SetParams(&proto.GetSecurityQuotesRequest{StockList: stocks})
	if err := client.do(req); err != nil {
		return nil, err
	}
	return req.Reply(), nil
}
// GetSecurityList fetches security codes in the given market, starting at
// position start (presumably an offset into the server-side list; confirm
// the page size in proto).
func (client *Client) GetSecurityList(market uint8, start uint16) (*proto.GetSecurityListReply, error) {
	client.mu.Lock()
	defer client.mu.Unlock()

	params := &proto.GetSecurityListRequest{
		Market: uint16(market),
		Start:  start,
	}
	req := proto.NewGetSecurityList()
	req.SetParams(params)
	if err := client.do(req); err != nil {
		return nil, err
	}
	return req.Reply(), nil
}
// GetSecurityBars fetches K-line (candlestick) bars for a stock.
// category is passed through as the bar category (presumably the bar period;
// see proto for valid values). start and count select the window of bars.
func (client *Client) GetSecurityBars(category uint16, market uint8, code string, start uint16, count uint16) (*proto.GetSecurityBarsReply, error) {
	client.mu.Lock()
	defer client.mu.Unlock()

	// The wire format holds a fixed 6-byte code: longer codes are truncated,
	// shorter ones are zero-padded.
	var codeBuf [6]byte
	copy(codeBuf[:], code)

	req := proto.NewGetSecurityBars()
	req.SetParams(&proto.GetSecurityBarsRequest{
		Market:   uint16(market),
		Code:     codeBuf,
		Category: category,
		Start:    start,
		Count:    count,
	})
	if err := client.do(req); err != nil {
		return nil, err
	}
	return req.Reply(), nil
}
// GetIndexBars fetches K-line (candlestick) bars for an index.
// category is passed through as the bar category (presumably the bar period;
// see proto for valid values). start and count select the window of bars.
func (client *Client) GetIndexBars(category uint16, market uint8, code string, start uint16, count uint16) (*proto.GetIndexBarsReply, error) {
	client.mu.Lock()
	defer client.mu.Unlock()

	// Fixed 6-byte code field: truncate long codes, zero-pad short ones.
	var codeBuf [6]byte
	copy(codeBuf[:], code)

	params := &proto.GetIndexBarsRequest{
		Market:   uint16(market),
		Code:     codeBuf,
		Category: category,
		Start:    start,
		Count:    count,
	}
	req := proto.NewGetIndexBars()
	req.SetParams(params)
	if err := client.do(req); err != nil {
		return nil, err
	}
	return req.Reply(), nil
}
// GetMinuteTimeData fetches intraday time-share (minute) chart data for the
// security identified by market and code.
func (client *Client) GetMinuteTimeData(market uint8, code string) (*proto.GetMinuteTimeDataReply, error) {
	client.mu.Lock()
	defer client.mu.Unlock()

	// Fixed 6-byte code field: truncate long codes, zero-pad short ones.
	var codeBuf [6]byte
	copy(codeBuf[:], code)

	req := proto.NewGetMinuteTimeData()
	req.SetParams(&proto.GetMinuteTimeDataRequest{
		Market: uint16(market),
		Code:   codeBuf,
	})
	if err := client.do(req); err != nil {
		return nil, err
	}
	return req.Reply(), nil
}
// GetHistoryMinuteTimeData fetches historical time-share (minute) chart data
// for one security on the given date (presumably encoded as YYYYMMDD;
// confirm against proto).
func (client *Client) GetHistoryMinuteTimeData(date uint32, market uint8, code string) (*proto.GetHistoryMinuteTimeDataReply, error) {
	client.mu.Lock()
	defer client.mu.Unlock()

	// Fixed 6-byte code field: truncate long codes, zero-pad short ones.
	var codeBuf [6]byte
	copy(codeBuf[:], code)

	params := &proto.GetHistoryMinuteTimeDataRequest{
		Date:   date,
		Market: market,
		Code:   codeBuf,
	}
	req := proto.NewGetHistoryMinuteTimeData()
	req.SetParams(params)
	if err := client.do(req); err != nil {
		return nil, err
	}
	return req.Reply(), nil
}
// GetTransactionData fetches individual transaction (trade) records for one
// security. start and count select the window of records.
func (client *Client) GetTransactionData(market uint8, code string, start uint16, count uint16) (*proto.GetTransactionDataReply, error) {
	client.mu.Lock()
	defer client.mu.Unlock()

	// Fixed 6-byte code field: truncate long codes, zero-pad short ones.
	var codeBuf [6]byte
	copy(codeBuf[:], code)

	req := proto.NewGetTransactionData()
	req.SetParams(&proto.GetTransactionDataRequest{
		Market: uint16(market),
		Code:   codeBuf,
		Start:  start,
		Count:  count,
	})
	if err := client.do(req); err != nil {
		return nil, err
	}
	return req.Reply(), nil
}
// GetHistoryTransactionData fetches historical transaction (trade) records
// for one security on the given date (presumably encoded as YYYYMMDD;
// confirm against proto). start and count select the window of records.
func (client *Client) GetHistoryTransactionData(date uint32, market uint8, code string, start uint16, count uint16) (*proto.GetHistoryTransactionDataReply, error) {
	client.mu.Lock()
	defer client.mu.Unlock()

	// Fixed 6-byte code field: truncate long codes, zero-pad short ones.
	var codeBuf [6]byte
	copy(codeBuf[:], code)

	params := &proto.GetHistoryTransactionDataRequest{
		Date:   date,
		Market: uint16(market),
		Code:   codeBuf,
		Start:  start,
		Count:  count,
	}
	req := proto.NewGetHistoryTransactionData()
	req.SetParams(params)
	if err := client.do(req); err != nil {
		return nil, err
	}
	return req.Reply(), nil
}