From 432b27f52bc19ad46a91cef7beeedebd2d41a8fc Mon Sep 17 00:00:00 2001 From: bense Date: Thu, 3 Oct 2024 02:11:09 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9Eoptions=E6=9D=A5=E9=85=8D?= =?UTF-8?q?=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 18 +++----- client.go | 56 ++++++++++++++---------- client_test.go | 75 +++++++++++++++------------------ constants.go | 4 -- examples/main.go | 17 +++----- options.go | 34 +++++++++++++++ proto/get_security_list.go | 3 +- proto/get_security_quotes.go | 3 +- proto/hello1.go | 3 +- proto/hello2.go | 3 +- util/string.go => proto/util.go | 2 +- 11 files changed, 119 insertions(+), 99 deletions(-) create mode 100644 options.go rename util/string.go => proto/util.go (96%) diff --git a/README.md b/README.md index dc8cf56..0d4f970 100644 --- a/README.md +++ b/README.md @@ -26,18 +26,14 @@ import ( ) func main() { - var opt = &gotdx.Opt{ - Host: "119.147.212.81", - Port: 7709, - } - api := gotdx.NewClient(opt) - connectReply, err := api.Connect() + tdx := gotdx.New(gotdx.WithTCPAddress("119.147.212.81:7709")) + _, err := tdx.Connect() if err != nil { - log.Println(err) + log.Fatalln(err) } - log.Println(connectReply.Info) + defer tdx.Disconnect() - reply, err := api.GetSecurityQuotes([]uint8{gotdx.MarketSh, gotdx.MarketSz}, []string{"000001", "600008"}) + reply, err := tdx.GetSecurityQuotes([]uint8{gotdx.MarketSh, gotdx.MarketSz}, []string{"000001", "600008"}) if err != nil { log.Println(err) } @@ -45,11 +41,9 @@ func main() { for _, obj := range reply.List { log.Printf("%+v", obj) } - - _ = api.Disconnect() - } + ``` diff --git a/client.go b/client.go index c1977fd..b20177f 100644 --- a/client.go +++ b/client.go @@ -9,39 +9,29 @@ import ( "io" "log" "net" - "strconv" - "strings" + "sync" ) -type Client struct { - conn net.Conn - opt *Opt - complete chan bool - sending chan bool -} - -type Opt struct { - Host string - Port int - MaxRetryTimes int -} - -func NewClient(opt *Opt) *Client { +func New(opts ...Option) *Client { client := &Client{} - if opt.MaxRetryTimes <= 0 { - opt.MaxRetryTimes = DefaultRetryTimes - } - client.opt = opt + client.opt = applyOptions(opts...) client.sending = make(chan bool, 1) client.complete = make(chan bool, 1) return client } +type Client struct { + conn net.Conn + opt *Options + complete chan bool + sending chan bool + mu sync.Mutex +} + func (client *Client) connect() error { - addr := strings.Join([]string{client.opt.Host, strconv.Itoa(client.opt.Port)}, ":") - conn, err := net.Dial("tcp", addr) + conn, err := net.Dial("tcp", client.opt.TCPAddress) if err != nil { return err } @@ -112,6 +102,8 @@ func (client *Client) do(msg proto.Msg) error { // Connect 连接券商行情服务器 func (client *Client) Connect() (*proto.Hello1Reply, error) { + client.mu.Lock() + defer client.mu.Unlock() err := client.connect() if err != nil { return nil, err @@ -126,11 +118,15 @@ func (client *Client) Connect() (*proto.Hello1Reply, error) { // Disconnect 断开服务器 func (client *Client) Disconnect() error { + client.mu.Lock() + defer client.mu.Unlock() return client.conn.Close() } // GetSecurityCount 获取指定市场内的证券数目 func (client *Client) GetSecurityCount(market uint16) (*proto.GetSecurityCountReply, error) { + client.mu.Lock() + defer client.mu.Unlock() obj := proto.NewGetSecurityCount() obj.SetParams(&proto.GetSecurityCountRequest{ Market: market, @@ -144,6 +140,8 @@ func (client *Client) GetSecurityCount(market uint16) (*proto.GetSecurityCountRe // GetSecurityQuotes 获取盘口五档报价 func (client *Client) GetSecurityQuotes(markets []uint8, codes []string) (*proto.GetSecurityQuotesReply, error) { + client.mu.Lock() + defer client.mu.Unlock() if len(markets) != len(codes) { return nil, errors.New("market code count error") } @@ -165,6 +163,8 @@ func (client *Client) GetSecurityQuotes(markets []uint8, codes []string) (*proto // GetSecurityList 获取市场内指定范围内的所有证券代码 func (client *Client) GetSecurityList(market uint8, start uint16) (*proto.GetSecurityListReply, error) { + client.mu.Lock() + defer client.mu.Unlock() obj := proto.NewGetSecurityList() _market := uint16(market) obj.SetParams(&proto.GetSecurityListRequest{Market: _market, Start: start}) @@ -177,6 +177,8 @@ func (client *Client) GetSecurityList(market uint8, start uint16) (*proto.GetSec // GetSecurityBars 获取股票K线 func (client *Client) GetSecurityBars(category uint16, market uint8, code string, start uint16, count uint16) (*proto.GetSecurityBarsReply, error) { + client.mu.Lock() + defer client.mu.Unlock() obj := proto.NewGetSecurityBars() _code := [6]byte{} _market := uint16(market) @@ -197,6 +199,8 @@ func (client *Client) GetSecurityBars(category uint16, market uint8, code string // GetIndexBars 获取指数K线 func (client *Client) GetIndexBars(category uint16, market uint8, code string, start uint16, count uint16) (*proto.GetIndexBarsReply, error) { + client.mu.Lock() + defer client.mu.Unlock() obj := proto.NewGetIndexBars() _code := [6]byte{} _market := uint16(market) @@ -217,6 +221,8 @@ func (client *Client) GetIndexBars(category uint16, market uint8, code string, s // GetMinuteTimeData 获取分时图数据 func (client *Client) GetMinuteTimeData(market uint8, code string) (*proto.GetMinuteTimeDataReply, error) { + client.mu.Lock() + defer client.mu.Unlock() obj := proto.NewGetMinuteTimeData() _code := [6]byte{} _market := uint16(market) @@ -234,6 +240,8 @@ func (client *Client) GetMinuteTimeData(market uint8, code string) (*proto.GetMi // GetHistoryMinuteTimeData 获取历史分时图数据 func (client *Client) GetHistoryMinuteTimeData(date uint32, market uint8, code string) (*proto.GetHistoryMinuteTimeDataReply, error) { + client.mu.Lock() + defer client.mu.Unlock() obj := proto.NewGetHistoryMinuteTimeData() _code := [6]byte{} copy(_code[:], code) @@ -251,6 +259,8 @@ func (client *Client) GetHistoryMinuteTimeData(date uint32, market uint8, code s // GetTransactionData 获取分时成交 func (client *Client) GetTransactionData(market uint8, code string, start uint16, count uint16) (*proto.GetTransactionDataReply, error) { + client.mu.Lock() + defer client.mu.Unlock() obj := proto.NewGetTransactionData() _code := [6]byte{} _market := uint16(market) @@ -270,6 +280,8 @@ func (client *Client) GetTransactionData(market uint8, code string, start uint16 // GetHistoryTransactionData 获取历史分时成交 func (client *Client) GetHistoryTransactionData(date uint32, market uint8, code string, start uint16, count uint16) (*proto.GetHistoryTransactionDataReply, error) { + client.mu.Lock() + defer client.mu.Unlock() obj := proto.NewGetHistoryTransactionData() _code := [6]byte{} _market := uint16(market) diff --git a/client_test.go b/client_test.go index 7b212f0..e75c57e 100644 --- a/client_test.go +++ b/client_test.go @@ -6,26 +6,21 @@ import ( "testing" ) -var opt = &Opt{ - Host: "119.147.212.81", - Port: 7709, -} - func newClient() *Client { - api := NewClient(opt) - reply, err := api.Connect() + tdx := New(WithTCPAddress("119.147.212.81:7709")) + reply, err := tdx.Connect() if err != nil { fmt.Println(err) } fmt.Println(reply.Info) - return api + return tdx } func Test_tdx_Connect(t *testing.T) { fmt.Println("================ Connect ================") - api := NewClient(opt) - defer api.Disconnect() - reply, err := api.Connect() + tdx := New(WithTCPAddress("119.147.212.81:7709")) + defer tdx.Disconnect() + reply, err := tdx.Connect() if err != nil { t.Errorf("error:%s", err) } @@ -34,9 +29,9 @@ func Test_tdx_Connect(t *testing.T) { func Test_tdx_GetSecurityCount(t *testing.T) { fmt.Println("================ GetSecurityCount ================") - api := newClient() - defer api.Disconnect() - reply, err := api.GetSecurityCount(MarketSh) + tdx := newClient() + defer tdx.Disconnect() + reply, err := tdx.GetSecurityCount(MarketSh) if err != nil { t.Errorf("error:%s", err) } @@ -45,9 +40,9 @@ func Test_tdx_GetSecurityCount(t *testing.T) { func Test_tdx_GetSecurityQuotes(t *testing.T) { fmt.Println("================ GetSecurityQuotes ================") - api := newClient() - defer api.Disconnect() - reply, err := api.GetSecurityQuotes([]uint8{MarketSh}, []string{"002062"}) + tdx := newClient() + defer tdx.Disconnect() + reply, err := tdx.GetSecurityQuotes([]uint8{MarketSh}, []string{"002062"}) if err != nil { t.Errorf("error:%s", err) } @@ -59,9 +54,9 @@ func Test_tdx_GetSecurityQuotes(t *testing.T) { func Test_tdx_GetSecurityList(t *testing.T) { fmt.Println("================ GetSecurityList ================") - api := newClient() - defer api.Disconnect() - reply, err := api.GetSecurityList(MarketSh, 0) + tdx := newClient() + defer tdx.Disconnect() + reply, err := tdx.GetSecurityList(MarketSh, 0) if err != nil { t.Errorf("error:%s", err) } @@ -73,9 +68,9 @@ func Test_tdx_GetSecurityList(t *testing.T) { func Test_tdx_GetSecurityBars(t *testing.T) { fmt.Println("================ GetSecurityBars ================") // GetSecurityBars 与 GetIndexBars 使用同一个接口靠market区分 - api := newClient() - defer api.Disconnect() - reply, err := api.GetSecurityBars(proto.KLINE_TYPE_RI_K, 0, "000001", 0, 10) + tdx := newClient() + defer tdx.Disconnect() + reply, err := tdx.GetSecurityBars(proto.KLINE_TYPE_RI_K, 0, "000001", 0, 10) if err != nil { t.Errorf("error:%s", err) } @@ -88,9 +83,9 @@ func Test_tdx_GetSecurityBars(t *testing.T) { func Test_tdx_GetIndexBars(t *testing.T) { fmt.Println("================ GetIndexBars ================") // GetSecurityBars 与 GetIndexBars 使用同一个接口靠market区分 - api := newClient() - defer api.Disconnect() - reply, err := api.GetIndexBars(proto.KLINE_TYPE_RI_K, 1, "000001", 0, 10) + tdx := newClient() + defer tdx.Disconnect() + reply, err := tdx.GetIndexBars(proto.KLINE_TYPE_RI_K, 1, "000001", 0, 10) if err != nil { t.Errorf("error:%s", err) } @@ -102,9 +97,9 @@ func Test_tdx_GetIndexBars(t *testing.T) { func Test_tdx_GetMinuteTimeData(t *testing.T) { fmt.Println("================ GetMinuteTimeData ================") - api := newClient() - defer api.Disconnect() - reply, err := api.GetMinuteTimeData(0, "159607") + tdx := newClient() + defer tdx.Disconnect() + reply, err := tdx.GetMinuteTimeData(0, "159607") if err != nil { t.Errorf("error:%s", err) } @@ -116,10 +111,10 @@ func Test_tdx_GetMinuteTimeData(t *testing.T) { func Test_tdx_GetHistoryMinuteTimeData(t *testing.T) { fmt.Println("================ GetHistoryMinuteTimeData ================") - api := newClient() - defer api.Disconnect() - //reply, err := api.GetHistoryMinuteTimeData(20220511, 0, "159607") - reply, err := api.GetHistoryMinuteTimeData(20220511, 0, "159607") + tdx := newClient() + defer tdx.Disconnect() + //reply, err := tdx.GetHistoryMinuteTimeData(20220511, 0, "159607") + reply, err := tdx.GetHistoryMinuteTimeData(20220511, 0, "159607") if err != nil { t.Errorf("error:%s", err) } @@ -131,10 +126,10 @@ func Test_tdx_GetHistoryMinuteTimeData(t *testing.T) { func Test_tdx_GetTransactionData(t *testing.T) { fmt.Println("================ GetTransactionData ================") - api := newClient() - defer api.Disconnect() - //reply, err := api.GetHistoryMinuteTimeData(20220511, 0, "159607") - reply, err := api.GetTransactionData(MarketSh, "159607", 0, 10) + tdx := newClient() + defer tdx.Disconnect() + //reply, err := tdx.GetHistoryMinuteTimeData(20220511, 0, "159607") + reply, err := tdx.GetTransactionData(MarketSh, "159607", 0, 10) if err != nil { t.Errorf("error:%s", err) } @@ -146,9 +141,9 @@ func Test_tdx_GetTransactionData(t *testing.T) { func Test_tdx_GetHistoryTransactionData(t *testing.T) { fmt.Println("================ GetHistoryTransactionData ================") - api := newClient() - defer api.Disconnect() - reply, err := api.GetHistoryTransactionData(20230922, MarketSh, "159607", 0, 10) + tdx := newClient() + defer tdx.Disconnect() + reply, err := tdx.GetHistoryTransactionData(20230922, MarketSh, "159607", 0, 10) if err != nil { t.Errorf("error:%s", err) } diff --git a/constants.go b/constants.go index f750c4c..149e18d 100644 --- a/constants.go +++ b/constants.go @@ -23,10 +23,6 @@ const ( KLINE_TYPE_YEARLY = 11 // 年K 线 ) -const ( - DefaultRetryTimes = 3 // 重试次数 -) - var ( ErrBadData = errors.New("more than 8M data") ) diff --git a/examples/main.go b/examples/main.go index 34cb02e..ff93bc9 100644 --- a/examples/main.go +++ b/examples/main.go @@ -6,18 +6,14 @@ import ( ) func main() { - var opt = &gotdx.Opt{ - Host: "119.147.212.81", - Port: 7709, - } - api := gotdx.NewClient(opt) - connectReply, err := api.Connect() + tdx := gotdx.New(gotdx.WithTCPAddress("119.147.212.81:7709")) + _, err := tdx.Connect() if err != nil { - log.Println(err) + log.Fatalln(err) } - log.Println(connectReply.Info) + defer tdx.Disconnect() - reply, err := api.GetSecurityQuotes([]uint8{gotdx.MarketSh, gotdx.MarketSz}, []string{"000001", "600008"}) + reply, err := tdx.GetSecurityQuotes([]uint8{gotdx.MarketSh, gotdx.MarketSz}, []string{"000001", "600008"}) if err != nil { log.Println(err) } @@ -25,7 +21,4 @@ func main() { for _, obj := range reply.List { log.Printf("%+v", obj) } - - _ = api.Disconnect() - } diff --git a/options.go b/options.go new file mode 100644 index 0000000..b3b8454 --- /dev/null +++ b/options.go @@ -0,0 +1,34 @@ +package gotdx + +const ( + _defaultTCPAddress = "119.147.212.81:7709" + _defaultRetryTimes = 3 +) + +type Options struct { + TCPAddress string // 服务器地址 + MaxRetryTimes int // 重试次数 +} + +func defaultOptions() *Options { + return &Options{ + TCPAddress: _defaultTCPAddress, + MaxRetryTimes: _defaultRetryTimes, + } +} + +func applyOptions(opts ...Option) *Options { + o := defaultOptions() + for _, opt := range opts { + opt(o) + } + return o +} + +type Option func(options *Options) + +func WithTCPAddress(tcpAddress string) Option { + return func(o *Options) { + o.TCPAddress = tcpAddress + } +} diff --git a/proto/get_security_list.go b/proto/get_security_list.go index 18b5ce9..06d854d 100644 --- a/proto/get_security_list.go +++ b/proto/get_security_list.go @@ -3,7 +3,6 @@ package proto import ( "bytes" "encoding/binary" - "github.com/bensema/gotdx/util" ) type GetSecurityList struct { @@ -88,7 +87,7 @@ func (obj *GetSecurityList) UnSerialize(header interface{}, data []byte) error { binary.Read(bytes.NewBuffer(data[pos:pos+8]), binary.LittleEndian, &name) pos += 8 - ele.Name = util.Utf8ToGbk(name[:]) + ele.Name = Utf8ToGbk(name[:]) pos += 4 binary.Read(bytes.NewBuffer(data[pos:pos+1]), binary.LittleEndian, &ele.DecimalPoint) diff --git a/proto/get_security_quotes.go b/proto/get_security_quotes.go index 88d1236..6aa76ec 100644 --- a/proto/get_security_quotes.go +++ b/proto/get_security_quotes.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "encoding/hex" "fmt" - "github.com/bensema/gotdx/util" ) type GetSecurityQuotes struct { @@ -146,7 +145,7 @@ func (obj *GetSecurityQuotes) UnSerialize(header interface{}, data []byte) error binary.Read(bytes.NewBuffer(data[pos:pos+6]), binary.LittleEndian, &code) //enc := mahonia.NewDecoder("gbk") //ele.Code = enc.ConvertString(string(code[:])) - ele.Code = util.Utf8ToGbk(code[:]) + ele.Code = Utf8ToGbk(code[:]) pos += 6 binary.Read(bytes.NewBuffer(data[pos:pos+2]), binary.LittleEndian, &ele.Active1) pos += 2 diff --git a/proto/hello1.go b/proto/hello1.go index efdc5f7..3d05476 100644 --- a/proto/hello1.go +++ b/proto/hello1.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/binary" "encoding/hex" - "github.com/bensema/gotdx/util" ) type Hello1 struct { @@ -60,7 +59,7 @@ func (obj *Hello1) Serialize() ([]byte, error) { func (obj *Hello1) UnSerialize(header interface{}, data []byte) error { obj.respHeader = header.(*RespHeader) - serverInfo := util.Utf8ToGbk(data[68:]) + serverInfo := Utf8ToGbk(data[68:]) obj.reply.Info = serverInfo return nil diff --git a/proto/hello2.go b/proto/hello2.go index 4429052..98e68e1 100644 --- a/proto/hello2.go +++ b/proto/hello2.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/binary" "encoding/hex" - "github.com/bensema/gotdx/util" ) type Hello2 struct { @@ -57,7 +56,7 @@ func (obj *Hello2) Serialize() ([]byte, error) { func (obj *Hello2) UnSerialize(header interface{}, data []byte) error { obj.respHeader = header.(*RespHeader) - serverInfo := util.Utf8ToGbk(data[58:]) + serverInfo := Utf8ToGbk(data[58:]) //fmt.Println(hex.EncodeToString(data)) obj.reply.Info = serverInfo return nil diff --git a/util/string.go b/proto/util.go similarity index 96% rename from util/string.go rename to proto/util.go index cb0080d..a630ee1 100644 --- a/util/string.go +++ b/proto/util.go @@ -1,4 +1,4 @@ -package util +package proto import ( "bytes"