新增options来配置

This commit is contained in:
bense
2024-10-03 02:11:09 +08:00
parent bfff0f62be
commit 432b27f52b
11 changed files with 119 additions and 99 deletions

View File

@@ -26,18 +26,14 @@ import (
) )
func main() { func main() {
var opt = &gotdx.Opt{ tdx := gotdx.New(gotdx.WithTCPAddress("119.147.212.81:7709"))
Host: "119.147.212.81", _, err := tdx.Connect()
Port: 7709,
}
api := gotdx.NewClient(opt)
connectReply, err := api.Connect()
if err != nil { if err != nil {
log.Println(err) log.Fatalln(err)
} }
log.Println(connectReply.Info) defer tdx.Disconnect()
reply, err := api.GetSecurityQuotes([]uint8{gotdx.MarketSh, gotdx.MarketSz}, []string{"000001", "600008"}) reply, err := tdx.GetSecurityQuotes([]uint8{gotdx.MarketSh, gotdx.MarketSz}, []string{"000001", "600008"})
if err != nil { if err != nil {
log.Println(err) log.Println(err)
} }
@@ -45,11 +41,9 @@ func main() {
for _, obj := range reply.List { for _, obj := range reply.List {
log.Printf("%+v", obj) log.Printf("%+v", obj)
} }
_ = api.Disconnect()
} }
``` ```

View File

@@ -9,39 +9,29 @@ import (
"io" "io"
"log" "log"
"net" "net"
"strconv" "sync"
"strings"
) )
type Client struct { func New(opts ...Option) *Client {
conn net.Conn
opt *Opt
complete chan bool
sending chan bool
}
type Opt struct {
Host string
Port int
MaxRetryTimes int
}
func NewClient(opt *Opt) *Client {
client := &Client{} client := &Client{}
if opt.MaxRetryTimes <= 0 {
opt.MaxRetryTimes = DefaultRetryTimes
}
client.opt = opt client.opt = applyOptions(opts...)
client.sending = make(chan bool, 1) client.sending = make(chan bool, 1)
client.complete = make(chan bool, 1) client.complete = make(chan bool, 1)
return client return client
} }
type Client struct {
conn net.Conn
opt *Options
complete chan bool
sending chan bool
mu sync.Mutex
}
func (client *Client) connect() error { func (client *Client) connect() error {
addr := strings.Join([]string{client.opt.Host, strconv.Itoa(client.opt.Port)}, ":") conn, err := net.Dial("tcp", client.opt.TCPAddress)
conn, err := net.Dial("tcp", addr)
if err != nil { if err != nil {
return err return err
} }
@@ -112,6 +102,8 @@ func (client *Client) do(msg proto.Msg) error {
// Connect 连接券商行情服务器 // Connect 连接券商行情服务器
func (client *Client) Connect() (*proto.Hello1Reply, error) { func (client *Client) Connect() (*proto.Hello1Reply, error) {
client.mu.Lock()
defer client.mu.Unlock()
err := client.connect() err := client.connect()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -126,11 +118,15 @@ func (client *Client) Connect() (*proto.Hello1Reply, error) {
// Disconnect 断开服务器 // Disconnect 断开服务器
func (client *Client) Disconnect() error { func (client *Client) Disconnect() error {
client.mu.Lock()
defer client.mu.Unlock()
return client.conn.Close() return client.conn.Close()
} }
// GetSecurityCount 获取指定市场内的证券数目 // GetSecurityCount 获取指定市场内的证券数目
func (client *Client) GetSecurityCount(market uint16) (*proto.GetSecurityCountReply, error) { func (client *Client) GetSecurityCount(market uint16) (*proto.GetSecurityCountReply, error) {
client.mu.Lock()
defer client.mu.Unlock()
obj := proto.NewGetSecurityCount() obj := proto.NewGetSecurityCount()
obj.SetParams(&proto.GetSecurityCountRequest{ obj.SetParams(&proto.GetSecurityCountRequest{
Market: market, Market: market,
@@ -144,6 +140,8 @@ func (client *Client) GetSecurityCount(market uint16) (*proto.GetSecurityCountRe
// GetSecurityQuotes 获取盘口五档报价 // GetSecurityQuotes 获取盘口五档报价
func (client *Client) GetSecurityQuotes(markets []uint8, codes []string) (*proto.GetSecurityQuotesReply, error) { func (client *Client) GetSecurityQuotes(markets []uint8, codes []string) (*proto.GetSecurityQuotesReply, error) {
client.mu.Lock()
defer client.mu.Unlock()
if len(markets) != len(codes) { if len(markets) != len(codes) {
return nil, errors.New("market code count error") return nil, errors.New("market code count error")
} }
@@ -165,6 +163,8 @@ func (client *Client) GetSecurityQuotes(markets []uint8, codes []string) (*proto
// GetSecurityList 获取市场内指定范围内的所有证券代码 // GetSecurityList 获取市场内指定范围内的所有证券代码
func (client *Client) GetSecurityList(market uint8, start uint16) (*proto.GetSecurityListReply, error) { func (client *Client) GetSecurityList(market uint8, start uint16) (*proto.GetSecurityListReply, error) {
client.mu.Lock()
defer client.mu.Unlock()
obj := proto.NewGetSecurityList() obj := proto.NewGetSecurityList()
_market := uint16(market) _market := uint16(market)
obj.SetParams(&proto.GetSecurityListRequest{Market: _market, Start: start}) obj.SetParams(&proto.GetSecurityListRequest{Market: _market, Start: start})
@@ -177,6 +177,8 @@ func (client *Client) GetSecurityList(market uint8, start uint16) (*proto.GetSec
// GetSecurityBars 获取股票K线 // GetSecurityBars 获取股票K线
func (client *Client) GetSecurityBars(category uint16, market uint8, code string, start uint16, count uint16) (*proto.GetSecurityBarsReply, error) { func (client *Client) GetSecurityBars(category uint16, market uint8, code string, start uint16, count uint16) (*proto.GetSecurityBarsReply, error) {
client.mu.Lock()
defer client.mu.Unlock()
obj := proto.NewGetSecurityBars() obj := proto.NewGetSecurityBars()
_code := [6]byte{} _code := [6]byte{}
_market := uint16(market) _market := uint16(market)
@@ -197,6 +199,8 @@ func (client *Client) GetSecurityBars(category uint16, market uint8, code string
// GetIndexBars 获取指数K线 // GetIndexBars 获取指数K线
func (client *Client) GetIndexBars(category uint16, market uint8, code string, start uint16, count uint16) (*proto.GetIndexBarsReply, error) { func (client *Client) GetIndexBars(category uint16, market uint8, code string, start uint16, count uint16) (*proto.GetIndexBarsReply, error) {
client.mu.Lock()
defer client.mu.Unlock()
obj := proto.NewGetIndexBars() obj := proto.NewGetIndexBars()
_code := [6]byte{} _code := [6]byte{}
_market := uint16(market) _market := uint16(market)
@@ -217,6 +221,8 @@ func (client *Client) GetIndexBars(category uint16, market uint8, code string, s
// GetMinuteTimeData 获取分时图数据 // GetMinuteTimeData 获取分时图数据
func (client *Client) GetMinuteTimeData(market uint8, code string) (*proto.GetMinuteTimeDataReply, error) { func (client *Client) GetMinuteTimeData(market uint8, code string) (*proto.GetMinuteTimeDataReply, error) {
client.mu.Lock()
defer client.mu.Unlock()
obj := proto.NewGetMinuteTimeData() obj := proto.NewGetMinuteTimeData()
_code := [6]byte{} _code := [6]byte{}
_market := uint16(market) _market := uint16(market)
@@ -234,6 +240,8 @@ func (client *Client) GetMinuteTimeData(market uint8, code string) (*proto.GetMi
// GetHistoryMinuteTimeData 获取历史分时图数据 // GetHistoryMinuteTimeData 获取历史分时图数据
func (client *Client) GetHistoryMinuteTimeData(date uint32, market uint8, code string) (*proto.GetHistoryMinuteTimeDataReply, error) { func (client *Client) GetHistoryMinuteTimeData(date uint32, market uint8, code string) (*proto.GetHistoryMinuteTimeDataReply, error) {
client.mu.Lock()
defer client.mu.Unlock()
obj := proto.NewGetHistoryMinuteTimeData() obj := proto.NewGetHistoryMinuteTimeData()
_code := [6]byte{} _code := [6]byte{}
copy(_code[:], code) copy(_code[:], code)
@@ -251,6 +259,8 @@ func (client *Client) GetHistoryMinuteTimeData(date uint32, market uint8, code s
// GetTransactionData 获取分时成交 // GetTransactionData 获取分时成交
func (client *Client) GetTransactionData(market uint8, code string, start uint16, count uint16) (*proto.GetTransactionDataReply, error) { func (client *Client) GetTransactionData(market uint8, code string, start uint16, count uint16) (*proto.GetTransactionDataReply, error) {
client.mu.Lock()
defer client.mu.Unlock()
obj := proto.NewGetTransactionData() obj := proto.NewGetTransactionData()
_code := [6]byte{} _code := [6]byte{}
_market := uint16(market) _market := uint16(market)
@@ -270,6 +280,8 @@ func (client *Client) GetTransactionData(market uint8, code string, start uint16
// GetHistoryTransactionData 获取历史分时成交 // GetHistoryTransactionData 获取历史分时成交
func (client *Client) GetHistoryTransactionData(date uint32, market uint8, code string, start uint16, count uint16) (*proto.GetHistoryTransactionDataReply, error) { func (client *Client) GetHistoryTransactionData(date uint32, market uint8, code string, start uint16, count uint16) (*proto.GetHistoryTransactionDataReply, error) {
client.mu.Lock()
defer client.mu.Unlock()
obj := proto.NewGetHistoryTransactionData() obj := proto.NewGetHistoryTransactionData()
_code := [6]byte{} _code := [6]byte{}
_market := uint16(market) _market := uint16(market)

View File

@@ -6,26 +6,21 @@ import (
"testing" "testing"
) )
var opt = &Opt{
Host: "119.147.212.81",
Port: 7709,
}
func newClient() *Client { func newClient() *Client {
api := NewClient(opt) tdx := New(WithTCPAddress("119.147.212.81:7709"))
reply, err := api.Connect() reply, err := tdx.Connect()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} }
fmt.Println(reply.Info) fmt.Println(reply.Info)
return api return tdx
} }
func Test_tdx_Connect(t *testing.T) { func Test_tdx_Connect(t *testing.T) {
fmt.Println("================ Connect ================") fmt.Println("================ Connect ================")
api := NewClient(opt) tdx := New(WithTCPAddress("119.147.212.81:7709"))
defer api.Disconnect() defer tdx.Disconnect()
reply, err := api.Connect() reply, err := tdx.Connect()
if err != nil { if err != nil {
t.Errorf("error:%s", err) t.Errorf("error:%s", err)
} }
@@ -34,9 +29,9 @@ func Test_tdx_Connect(t *testing.T) {
func Test_tdx_GetSecurityCount(t *testing.T) { func Test_tdx_GetSecurityCount(t *testing.T) {
fmt.Println("================ GetSecurityCount ================") fmt.Println("================ GetSecurityCount ================")
api := newClient() tdx := newClient()
defer api.Disconnect() defer tdx.Disconnect()
reply, err := api.GetSecurityCount(MarketSh) reply, err := tdx.GetSecurityCount(MarketSh)
if err != nil { if err != nil {
t.Errorf("error:%s", err) t.Errorf("error:%s", err)
} }
@@ -45,9 +40,9 @@ func Test_tdx_GetSecurityCount(t *testing.T) {
func Test_tdx_GetSecurityQuotes(t *testing.T) { func Test_tdx_GetSecurityQuotes(t *testing.T) {
fmt.Println("================ GetSecurityQuotes ================") fmt.Println("================ GetSecurityQuotes ================")
api := newClient() tdx := newClient()
defer api.Disconnect() defer tdx.Disconnect()
reply, err := api.GetSecurityQuotes([]uint8{MarketSh}, []string{"002062"}) reply, err := tdx.GetSecurityQuotes([]uint8{MarketSh}, []string{"002062"})
if err != nil { if err != nil {
t.Errorf("error:%s", err) t.Errorf("error:%s", err)
} }
@@ -59,9 +54,9 @@ func Test_tdx_GetSecurityQuotes(t *testing.T) {
func Test_tdx_GetSecurityList(t *testing.T) { func Test_tdx_GetSecurityList(t *testing.T) {
fmt.Println("================ GetSecurityList ================") fmt.Println("================ GetSecurityList ================")
api := newClient() tdx := newClient()
defer api.Disconnect() defer tdx.Disconnect()
reply, err := api.GetSecurityList(MarketSh, 0) reply, err := tdx.GetSecurityList(MarketSh, 0)
if err != nil { if err != nil {
t.Errorf("error:%s", err) t.Errorf("error:%s", err)
} }
@@ -73,9 +68,9 @@ func Test_tdx_GetSecurityList(t *testing.T) {
func Test_tdx_GetSecurityBars(t *testing.T) { func Test_tdx_GetSecurityBars(t *testing.T) {
fmt.Println("================ GetSecurityBars ================") fmt.Println("================ GetSecurityBars ================")
// GetSecurityBars 与 GetIndexBars 使用同一个接口靠market区分 // GetSecurityBars 与 GetIndexBars 使用同一个接口靠market区分
api := newClient() tdx := newClient()
defer api.Disconnect() defer tdx.Disconnect()
reply, err := api.GetSecurityBars(proto.KLINE_TYPE_RI_K, 0, "000001", 0, 10) reply, err := tdx.GetSecurityBars(proto.KLINE_TYPE_RI_K, 0, "000001", 0, 10)
if err != nil { if err != nil {
t.Errorf("error:%s", err) t.Errorf("error:%s", err)
} }
@@ -88,9 +83,9 @@ func Test_tdx_GetSecurityBars(t *testing.T) {
func Test_tdx_GetIndexBars(t *testing.T) { func Test_tdx_GetIndexBars(t *testing.T) {
fmt.Println("================ GetIndexBars ================") fmt.Println("================ GetIndexBars ================")
// GetSecurityBars 与 GetIndexBars 使用同一个接口靠market区分 // GetSecurityBars 与 GetIndexBars 使用同一个接口靠market区分
api := newClient() tdx := newClient()
defer api.Disconnect() defer tdx.Disconnect()
reply, err := api.GetIndexBars(proto.KLINE_TYPE_RI_K, 1, "000001", 0, 10) reply, err := tdx.GetIndexBars(proto.KLINE_TYPE_RI_K, 1, "000001", 0, 10)
if err != nil { if err != nil {
t.Errorf("error:%s", err) t.Errorf("error:%s", err)
} }
@@ -102,9 +97,9 @@ func Test_tdx_GetIndexBars(t *testing.T) {
func Test_tdx_GetMinuteTimeData(t *testing.T) { func Test_tdx_GetMinuteTimeData(t *testing.T) {
fmt.Println("================ GetMinuteTimeData ================") fmt.Println("================ GetMinuteTimeData ================")
api := newClient() tdx := newClient()
defer api.Disconnect() defer tdx.Disconnect()
reply, err := api.GetMinuteTimeData(0, "159607") reply, err := tdx.GetMinuteTimeData(0, "159607")
if err != nil { if err != nil {
t.Errorf("error:%s", err) t.Errorf("error:%s", err)
} }
@@ -116,10 +111,10 @@ func Test_tdx_GetMinuteTimeData(t *testing.T) {
func Test_tdx_GetHistoryMinuteTimeData(t *testing.T) { func Test_tdx_GetHistoryMinuteTimeData(t *testing.T) {
fmt.Println("================ GetHistoryMinuteTimeData ================") fmt.Println("================ GetHistoryMinuteTimeData ================")
api := newClient() tdx := newClient()
defer api.Disconnect() defer tdx.Disconnect()
//reply, err := api.GetHistoryMinuteTimeData(20220511, 0, "159607") //reply, err := tdx.GetHistoryMinuteTimeData(20220511, 0, "159607")
reply, err := api.GetHistoryMinuteTimeData(20220511, 0, "159607") reply, err := tdx.GetHistoryMinuteTimeData(20220511, 0, "159607")
if err != nil { if err != nil {
t.Errorf("error:%s", err) t.Errorf("error:%s", err)
} }
@@ -131,10 +126,10 @@ func Test_tdx_GetHistoryMinuteTimeData(t *testing.T) {
func Test_tdx_GetTransactionData(t *testing.T) { func Test_tdx_GetTransactionData(t *testing.T) {
fmt.Println("================ GetTransactionData ================") fmt.Println("================ GetTransactionData ================")
api := newClient() tdx := newClient()
defer api.Disconnect() defer tdx.Disconnect()
//reply, err := api.GetHistoryMinuteTimeData(20220511, 0, "159607") //reply, err := tdx.GetHistoryMinuteTimeData(20220511, 0, "159607")
reply, err := api.GetTransactionData(MarketSh, "159607", 0, 10) reply, err := tdx.GetTransactionData(MarketSh, "159607", 0, 10)
if err != nil { if err != nil {
t.Errorf("error:%s", err) t.Errorf("error:%s", err)
} }
@@ -146,9 +141,9 @@ func Test_tdx_GetTransactionData(t *testing.T) {
func Test_tdx_GetHistoryTransactionData(t *testing.T) { func Test_tdx_GetHistoryTransactionData(t *testing.T) {
fmt.Println("================ GetHistoryTransactionData ================") fmt.Println("================ GetHistoryTransactionData ================")
api := newClient() tdx := newClient()
defer api.Disconnect() defer tdx.Disconnect()
reply, err := api.GetHistoryTransactionData(20230922, MarketSh, "159607", 0, 10) reply, err := tdx.GetHistoryTransactionData(20230922, MarketSh, "159607", 0, 10)
if err != nil { if err != nil {
t.Errorf("error:%s", err) t.Errorf("error:%s", err)
} }

View File

@@ -23,10 +23,6 @@ const (
KLINE_TYPE_YEARLY = 11 // 年K 线 KLINE_TYPE_YEARLY = 11 // 年K 线
) )
const (
DefaultRetryTimes = 3 // 重试次数
)
var ( var (
ErrBadData = errors.New("more than 8M data") ErrBadData = errors.New("more than 8M data")
) )

View File

@@ -6,18 +6,14 @@ import (
) )
func main() { func main() {
var opt = &gotdx.Opt{ tdx := gotdx.New(gotdx.WithTCPAddress("119.147.212.81:7709"))
Host: "119.147.212.81", _, err := tdx.Connect()
Port: 7709,
}
api := gotdx.NewClient(opt)
connectReply, err := api.Connect()
if err != nil { if err != nil {
log.Println(err) log.Fatalln(err)
} }
log.Println(connectReply.Info) defer tdx.Disconnect()
reply, err := api.GetSecurityQuotes([]uint8{gotdx.MarketSh, gotdx.MarketSz}, []string{"000001", "600008"}) reply, err := tdx.GetSecurityQuotes([]uint8{gotdx.MarketSh, gotdx.MarketSz}, []string{"000001", "600008"})
if err != nil { if err != nil {
log.Println(err) log.Println(err)
} }
@@ -25,7 +21,4 @@ func main() {
for _, obj := range reply.List { for _, obj := range reply.List {
log.Printf("%+v", obj) log.Printf("%+v", obj)
} }
_ = api.Disconnect()
} }

34
options.go Normal file
View File

@@ -0,0 +1,34 @@
package gotdx
const (
_defaultTCPAddress = "119.147.212.81:7709"
_defaultRetryTimes = 3
)
type Options struct {
TCPAddress string // 服务器地址
MaxRetryTimes int // 重试次数
}
func defaultOptions() *Options {
return &Options{
TCPAddress: _defaultTCPAddress,
MaxRetryTimes: _defaultRetryTimes,
}
}
func applyOptions(opts ...Option) *Options {
o := defaultOptions()
for _, opt := range opts {
opt(o)
}
return o
}
type Option func(options *Options)
func WithTCPAddress(tcpAddress string) Option {
return func(o *Options) {
o.TCPAddress = tcpAddress
}
}

View File

@@ -3,7 +3,6 @@ package proto
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"github.com/bensema/gotdx/util"
) )
type GetSecurityList struct { type GetSecurityList struct {
@@ -88,7 +87,7 @@ func (obj *GetSecurityList) UnSerialize(header interface{}, data []byte) error {
binary.Read(bytes.NewBuffer(data[pos:pos+8]), binary.LittleEndian, &name) binary.Read(bytes.NewBuffer(data[pos:pos+8]), binary.LittleEndian, &name)
pos += 8 pos += 8
ele.Name = util.Utf8ToGbk(name[:]) ele.Name = Utf8ToGbk(name[:])
pos += 4 pos += 4
binary.Read(bytes.NewBuffer(data[pos:pos+1]), binary.LittleEndian, &ele.DecimalPoint) binary.Read(bytes.NewBuffer(data[pos:pos+1]), binary.LittleEndian, &ele.DecimalPoint)

View File

@@ -5,7 +5,6 @@ import (
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"github.com/bensema/gotdx/util"
) )
type GetSecurityQuotes struct { type GetSecurityQuotes struct {
@@ -146,7 +145,7 @@ func (obj *GetSecurityQuotes) UnSerialize(header interface{}, data []byte) error
binary.Read(bytes.NewBuffer(data[pos:pos+6]), binary.LittleEndian, &code) binary.Read(bytes.NewBuffer(data[pos:pos+6]), binary.LittleEndian, &code)
//enc := mahonia.NewDecoder("gbk") //enc := mahonia.NewDecoder("gbk")
//ele.Code = enc.ConvertString(string(code[:])) //ele.Code = enc.ConvertString(string(code[:]))
ele.Code = util.Utf8ToGbk(code[:]) ele.Code = Utf8ToGbk(code[:])
pos += 6 pos += 6
binary.Read(bytes.NewBuffer(data[pos:pos+2]), binary.LittleEndian, &ele.Active1) binary.Read(bytes.NewBuffer(data[pos:pos+2]), binary.LittleEndian, &ele.Active1)
pos += 2 pos += 2

View File

@@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"github.com/bensema/gotdx/util"
) )
type Hello1 struct { type Hello1 struct {
@@ -60,7 +59,7 @@ func (obj *Hello1) Serialize() ([]byte, error) {
func (obj *Hello1) UnSerialize(header interface{}, data []byte) error { func (obj *Hello1) UnSerialize(header interface{}, data []byte) error {
obj.respHeader = header.(*RespHeader) obj.respHeader = header.(*RespHeader)
serverInfo := util.Utf8ToGbk(data[68:]) serverInfo := Utf8ToGbk(data[68:])
obj.reply.Info = serverInfo obj.reply.Info = serverInfo
return nil return nil

View File

@@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"github.com/bensema/gotdx/util"
) )
type Hello2 struct { type Hello2 struct {
@@ -57,7 +56,7 @@ func (obj *Hello2) Serialize() ([]byte, error) {
func (obj *Hello2) UnSerialize(header interface{}, data []byte) error { func (obj *Hello2) UnSerialize(header interface{}, data []byte) error {
obj.respHeader = header.(*RespHeader) obj.respHeader = header.(*RespHeader)
serverInfo := util.Utf8ToGbk(data[58:]) serverInfo := Utf8ToGbk(data[58:])
//fmt.Println(hex.EncodeToString(data)) //fmt.Println(hex.EncodeToString(data))
obj.reply.Info = serverInfo obj.reply.Info = serverInfo
return nil return nil

View File

@@ -1,4 +1,4 @@
package util package proto
import ( import (
"bytes" "bytes"