package rep import ( "encoding/binary" "encoding/json" "io" "git.crumpington.com/public/jldb/lib/errs" "git.crumpington.com/public/jldb/lib/httpconn" "git.crumpington.com/public/jldb/lib/wal" "net" "sync" "time" ) type client struct { // Mutex-protected variables. lock sync.Mutex closed bool conn net.Conn // The following are constant. endpoint string psk []byte timeout time.Duration buf []byte } func newClient(endpoint, psk string, timeout time.Duration) *client { b := make([]byte, 256) copy(b, []byte(psk)) return &client{ endpoint: endpoint, psk: b, timeout: timeout, } } func (c *client) GetInfo() (info Info, err error) { err = c.withConn(cmdGetInfo, func(conn net.Conn) error { return c.recvJSON(&info, conn, c.timeout) }) return info, err } func (c *client) RecvState(recv func(net.Conn) error) error { return c.withConn(cmdSendState, recv) } func (c *client) StreamWAL(w *wal.WAL) error { return c.withConn(cmdStreamWAL, func(conn net.Conn) error { return w.Recv(conn, c.timeout) }) } func (c *client) Close() { c.lock.Lock() defer c.lock.Unlock() c.closed = true if c.conn != nil { c.conn.Close() c.conn = nil } } // ---------------------------------------------------------------------------- func (c *client) writeCmd(cmd byte) error { c.conn.SetWriteDeadline(time.Now().Add(c.timeout)) if _, err := c.conn.Write([]byte{cmd}); err != nil { return errs.IO.WithErr(err) } return nil } func (c *client) dial() error { c.conn = nil conn, err := httpconn.Dial(c.endpoint) if err != nil { return err } conn.SetWriteDeadline(time.Now().Add(c.timeout)) if _, err := conn.Write(c.psk); err != nil { conn.Close() return errs.IO.WithErr(err) } c.conn = conn return nil } func (c *client) withConn(cmd byte, fn func(net.Conn) error) error { conn, err := c.getConn(cmd) if err != nil { return err } if err := fn(conn); err != nil { conn.Close() return err } return nil } func (c *client) getConn(cmd byte) (net.Conn, error) { c.lock.Lock() defer c.lock.Unlock() if c.closed { return nil, errs.IO.WithErr(io.EOF) } dialed := false if c.conn == nil { if err := c.dial(); err != nil { return nil, err } dialed = true } if err := c.writeCmd(cmd); err != nil { if dialed { c.conn = nil return nil, err } if err := c.dial(); err != nil { return nil, err } if err := c.writeCmd(cmd); err != nil { return nil, err } } return c.conn, nil } func (c *client) recvJSON( item any, conn net.Conn, timeout time.Duration, ) error { if cap(c.buf) < 2 { c.buf = make([]byte, 0, 1024) } buf := c.buf[:2] conn.SetReadDeadline(time.Now().Add(timeout)) if _, err := io.ReadFull(conn, buf); err != nil { return errs.IO.WithErr(err) } size := binary.LittleEndian.Uint16(buf) if cap(buf) < int(size) { buf = make([]byte, size) c.buf = buf } buf = buf[:size] if _, err := io.ReadFull(conn, buf); err != nil { return errs.IO.WithErr(err) } if err := json.Unmarshal(buf, item); err != nil { return errs.Unexpected.WithErr(err) } return nil }