diff --git a/client.go b/client.go index dbcf35b..6d426ae 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package tdx import ( + "fmt" "github.com/injoyai/base/maps" "github.com/injoyai/base/maps/wait/v2" "github.com/injoyai/conv" @@ -12,15 +13,30 @@ import ( "time" ) +// WithDebug 是否打印通讯数据 +func WithDebug(b ...bool) client.Option { + return func(c *client.Client) { + c.Logger.Debug(b...) + } +} + +// WithRedial 断线重连 +func WithRedial(b ...bool) client.Option { + return func(c *client.Client) { + c.SetRedial(b...) + } +} + // Dial 与服务器建立连接 func Dial(addr string, op ...client.Option) (cli *Client, err error) { cli = &Client{ - w: wait.New(time.Second * 2), - m: maps.NewSafe(), + Wait: wait.New(time.Second * 2), + m: maps.NewSafe(), } - cli.c, err = dial.TCP(addr, func(c *client.Client) { + cli.Client, err = dial.TCP(addr, func(c *client.Client) { + c.Logger.Debug(true) //开启日志打印 c.Logger.WithHEX() //以HEX显示 c.SetOption(op...) //自定义选项 c.Event.OnReadFrom = protocol.ReadFrom //分包 @@ -31,32 +47,26 @@ func Dial(addr string, op ...client.Option) (cli *Client, err error) { _, err := w.Write(bs) return err }) + + f := protocol.MConnect.Frame() + if _, err = c.Write(f.Bytes()); err != nil { + c.Close() + } }) if err != nil { return nil, err } - go cli.c.Run() + go cli.Client.Run() - err = cli.connect() - if err != nil { - cli.c.Close() - return nil, err - } - - return cli, err + return cli, nil } type Client struct { - c *client.Client - w *wait.Entity - m *maps.Safe - msgID uint32 -} - -// Done 连接关闭 -func (this *Client) Done() <-chan struct{} { - return this.c.Done() + *client.Client //客户端实例 + Wait *wait.Entity //异步回调,设置超时时间,超时则返回错误 + m *maps.Safe //有部分解析需要用到代码,返回数据获取不到,固请求的时候缓存下 + msgID uint32 //消息id,使用SendFrame自动累加 } // handlerDealMessage 处理服务器响应的数据 @@ -76,6 +86,8 @@ func (this *Client) handlerDealMessage(c *client.Client, msg ios.Acker) { case protocol.TypeConnect: + case protocol.TypeHeart: + case protocol.TypeStockCount: resp, err = protocol.MStockCount.Decode(f.Data) @@ -94,6 +106,9 @@ func (this *Client) handlerDealMessage(c *client.Client, msg ios.Acker) { case protocol.TypeStockHistoryMinuteTrade: resp, err = protocol.MStockHistoryMinuteTrade.Decode(f.Data, code) + default: + err = fmt.Errorf("通讯类型未解析:0x%X", f.Type) + } if err != nil { @@ -101,26 +116,17 @@ func (this *Client) handlerDealMessage(c *client.Client, msg ios.Acker) { return } - this.w.Done(conv.String(f.MsgID), resp) + this.Wait.Done(conv.String(f.MsgID), resp) } func (this *Client) SendFrame(f *protocol.Frame) (any, error) { this.msgID++ f.MsgID = this.msgID - if _, err := this.c.Write(f.Bytes()); err != nil { + if _, err := this.Client.Write(f.Bytes()); err != nil { return nil, err } - return this.w.Wait(conv.String(this.msgID)) -} - -// Write 实现io.Writer,向服务器写入数据 -func (this *Client) Write(bs []byte) (int, error) { - return this.c.Write(bs) -} - -func (this *Client) Close() error { - return this.c.Close() + return this.Wait.Wait(conv.String(this.msgID)) } func (this *Client) connect() error { diff --git a/example/GetStockHistoryMinuteTrade/main.go b/example/GetStockHistoryMinuteTrade/main.go index 1cab4d1..e2fe06a 100644 --- a/example/GetStockHistoryMinuteTrade/main.go +++ b/example/GetStockHistoryMinuteTrade/main.go @@ -10,7 +10,7 @@ import ( func main() { common.Test(func(c *tdx.Client) { - t := time.Date(2024, 10, 29, 0, 0, 0, 0, time.Local) + t := time.Date(2024, 10, 28, 0, 0, 0, 0, time.Local) resp, err := c.GetStockHistoryMinuteTrade(t, protocol.ExchangeSH, "000001", 0, 2000) logs.PanicErr(err)