diff --git a/README.md b/README.md index b9d291e..87e3072 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,8 @@ ## Roadmap +* Use probe and relayed-probe packets vs ping/pong. * Rename Mediator -> Relay -* Node: use symmetric encryption after handshake -* AEAD-AES uses a 12 byte nonce. We need to shrink the header: - * Remove Forward and replace it with a HeaderFlags bitfield. - * Forward, Asym/Sym, ... * Use default port 456 * Remove signing key from hub * Peer: UDP hole-punching diff --git a/node/main.go b/node/main.go index 35a00e6..c291e73 100644 --- a/node/main.go +++ b/node/main.go @@ -112,7 +112,7 @@ func main(netName, listenIP string, port uint16) { } go newHubPoller(netName, conf, peers).Run() - go readFromConn(conf.PeerIP, conn, peers) + go readFromConn(conn, peers) readFromIFace(iface, peers) } @@ -130,7 +130,7 @@ func determinePort(confPort, portFromCommandLine uint16) uint16 { // ---------------------------------------------------------------------------- -func readFromConn(localIP byte, conn *net.UDPConn, peers remotePeers) { +func readFromConn(conn *net.UDPConn, peers remotePeers) { defer panicHandler() diff --git a/node/packets-util.go b/node/packets-util.go new file mode 100644 index 0000000..8a6e13a --- /dev/null +++ b/node/packets-util.go @@ -0,0 +1,165 @@ +package node + +import ( + "net/netip" + "sync/atomic" + "unsafe" + "vppn/fasttime" +) + +var ( + traceIDCounter uint64 +) + +func newTraceID() uint64 { + return uint64(fasttime.Now()<<30) + atomic.AddUint64(&traceIDCounter, 1) +} + +// ---------------------------------------------------------------------------- + +type binWriter struct { + b []byte + i int +} + +func newBinWriter(buf []byte) *binWriter { + buf = buf[:cap(buf)] + return &binWriter{buf, 0} +} + +func (w *binWriter) Bool(b bool) *binWriter { + if b { + return w.Byte(1) + } + return w.Byte(0) +} + +func (w *binWriter) Byte(b byte) *binWriter { + w.b[w.i] = b + w.i++ + return w +} + +func (w *binWriter) SharedKey(key [32]byte) *binWriter { + copy(w.b[w.i:w.i+32], key[:]) + w.i += 32 + return w +} + +func (w *binWriter) Uint16(x uint16) *binWriter { + *(*uint16)(unsafe.Pointer(&w.b[w.i])) = x + w.i += 2 + return w +} + +func (w *binWriter) Uint64(x uint64) *binWriter { + *(*uint64)(unsafe.Pointer(&w.b[w.i])) = x + w.i += 8 + return w +} + +func (w *binWriter) Int64(x int64) *binWriter { + *(*int64)(unsafe.Pointer(&w.b[w.i])) = x + w.i += 8 + return w +} + +func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter { + addr := addrPort.Addr().As16() + copy(w.b[w.i:w.i+16], addr[:]) + w.i += 16 + return w.Uint16(addrPort.Port()) +} + +func (w *binWriter) Build() []byte { + return w.b[:w.i] +} + +// ---------------------------------------------------------------------------- + +type binReader struct { + b []byte + i int + err error +} + +func newBinReader(buf []byte) *binReader { + return &binReader{b: buf} +} + +func (r *binReader) hasBytes(n int) bool { + if r.err != nil || (len(r.b)-r.i) < n { + r.err = errMalformedPacket + return false + } + return true +} + +func (r *binReader) Bool(b *bool) *binReader { + var bb byte + r.Byte(&bb) + *b = bb != 0 + return r +} + +func (r *binReader) Byte(b *byte) *binReader { + if !r.hasBytes(1) { + return r + } + *b = r.b[r.i] + r.i++ + return r +} + +func (r *binReader) SharedKey(x *[32]byte) *binReader { + if !r.hasBytes(32) { + return r + } + *x = ([32]byte)(r.b[r.i : r.i+32]) + r.i += 32 + return r +} + +func (r *binReader) Uint16(x *uint16) *binReader { + if !r.hasBytes(2) { + return r + } + *x = *(*uint16)(unsafe.Pointer(&r.b[r.i])) + r.i += 2 + return r +} + +func (r *binReader) Uint64(x *uint64) *binReader { + if !r.hasBytes(8) { + return r + } + *x = *(*uint64)(unsafe.Pointer(&r.b[r.i])) + r.i += 8 + return r +} + +func (r *binReader) Int64(x *int64) *binReader { + if !r.hasBytes(8) { + return r + } + *x = *(*int64)(unsafe.Pointer(&r.b[r.i])) + r.i += 8 + return r +} + +func (r *binReader) AddrPort(x *netip.AddrPort) *binReader { + if !r.hasBytes(18) { + return r + } + addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])) + addr = addr.Unmap() + r.i += 16 + var port uint16 + r.Uint16(&port) + *x = netip.AddrPortFrom(addr, port) + return r +} + +func (r *binReader) Error() error { + return r.err +} diff --git a/node/packets-util_test.go b/node/packets-util_test.go new file mode 100644 index 0000000..06b0370 --- /dev/null +++ b/node/packets-util_test.go @@ -0,0 +1,40 @@ +package node + +import ( + "net/netip" + "reflect" + "testing" +) + +func TestBinWriteRead(t *testing.T) { + buf := make([]byte, 1024) + + type Item struct { + Type byte + TraceID uint64 + DestAddr netip.AddrPort + } + + in := Item{1, 2, netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22)} + + buf = newBinWriter(buf). + Byte(in.Type). + Uint64(in.TraceID). + AddrPort(in.DestAddr). + Build() + + out := Item{} + + err := newBinReader(buf). + Byte(&out.Type). + Uint64(&out.TraceID). + AddrPort(&out.DestAddr). + Error() + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(in, out) { + t.Fatal(in, out) + } +} diff --git a/node/packets.go b/node/packets.go index 57c7341..bbc1262 100644 --- a/node/packets.go +++ b/node/packets.go @@ -13,8 +13,12 @@ var ( ) const ( - packetTypePing = iota + 1 + packetTypeSyn = iota + 1 + packetTypeSynAck + packetTypeAck + packetTypePing packetTypePong + packetTypeRelayed ) // ---------------------------------------------------------------------------- @@ -31,6 +35,8 @@ func (p *controlPacket) ParsePayload(buf []byte) (err error) { p.Payload, err = parsePingPacket(buf) case packetTypePong: p.Payload, err = parsePongPacket(buf) + case packetTypeSyn: + p.Payload, err = parseSynPacket(buf) default: return errUnknownPacketType } @@ -39,34 +45,102 @@ func (p *controlPacket) ParsePayload(buf []byte) (err error) { // ---------------------------------------------------------------------------- +type synPacket struct { + TraceID uint64 // TraceID to match response w/ request. + SharedKey [32]byte // Our shared key. + ServerAddr netip.AddrPort // The address we're sending to. + Direct bool // True if this is request isn't relayed. +} + +func (p synPacket) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeSyn). + Uint64(p.TraceID). + SharedKey(p.SharedKey). + AddrPort(p.ServerAddr). + Bool(p.Direct). + Build() +} + +func parseSynPacket(buf []byte) (p synPacket, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + SharedKey(&p.SharedKey). + AddrPort(&p.ServerAddr). + Bool(&p.Direct). + Error() + return +} + +// ---------------------------------------------------------------------------- + +type synAckPacket struct { + TraceID uint64 +} + +func (p synAckPacket) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeSynAck). + Uint64(p.TraceID). + Build() +} + +func parseSynAckPacket(buf []byte) (p synAckPacket, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + Error() + return +} + +// ---------------------------------------------------------------------------- + +type ackPacket struct { + TraceID uint64 +} + +func (p ackPacket) Marshal(buf []byte) []byte { + return newBinWriter(buf). + Byte(packetTypeSynAck). + Uint64(p.TraceID). + Build() +} + +func parseAckPacket(buf []byte) (p ackPacket, err error) { + err = newBinReader(buf[1:]). + Uint64(&p.TraceID). + Error() + return +} + +// ---------------------------------------------------------------------------- + // A pingPacket is sent from a node acting as a client, to a node acting // as a server. It always contains the shared key the client is expecting // to use for data encryption with the server. type pingPacket struct { - SentAt int64 // UnixMilli. + SentAt int64 // UnixMilli. // Not used. Use traceID. SharedKey [32]byte } func newPingPacket(sharedKey [32]byte) (pp pingPacket) { pp.SentAt = time.Now().UnixMilli() - copy(pp.SharedKey[:], sharedKey[:]) + pp.SharedKey = sharedKey return } func (p pingPacket) Marshal(buf []byte) []byte { - buf = buf[:41] - buf[0] = packetTypePing - *(*uint64)(unsafe.Pointer(&buf[1])) = uint64(p.SentAt) - copy(buf[9:41], p.SharedKey[:]) - return buf + return newBinWriter(buf). + Byte(packetTypePing). + Int64(p.SentAt). + SharedKey(p.SharedKey). + Build() } func parsePingPacket(buf []byte) (p pingPacket, err error) { - if len(buf) != 41 { - return p, errMalformedPacket - } - p.SentAt = *(*int64)(unsafe.Pointer(&buf[1])) - copy(p.SharedKey[:], buf[9:41]) + err = newBinReader(buf[1:]). + Int64(&p.SentAt). + SharedKey(&p.SharedKey). + Error() return } diff --git a/node/packets_test.go b/node/packets_test.go index da242d4..6d96ccb 100644 --- a/node/packets_test.go +++ b/node/packets_test.go @@ -2,10 +2,59 @@ package node import ( "crypto/rand" + "net/netip" "reflect" "testing" ) +func TestPacketSyn(t *testing.T) { + in := synPacket{ + TraceID: newTraceID(), + Direct: true, + ServerAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 34), + } + rand.Read(in.SharedKey[:]) + + out, err := parseSynPacket(in.Marshal(make([]byte, bufferSize))) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(in, out) { + t.Fatal("\n", in, "\n", out) + } +} + +func TestPacketSynAck(t *testing.T) { + in := synAckPacket{ + TraceID: newTraceID(), + } + + out, err := parseSynAckPacket(in.Marshal(make([]byte, bufferSize))) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(in, out) { + t.Fatal("\n", in, "\n", out) + } +} + +func TestPacketAck(t *testing.T) { + in := ackPacket{ + TraceID: newTraceID(), + } + + out, err := parseAckPacket(in.Marshal(make([]byte, bufferSize))) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(in, out) { + t.Fatal("\n", in, "\n", out) + } +} + func TestPacketPing(t *testing.T) { sharedKey := make([]byte, 32) rand.Read(sharedKey) diff --git a/node/peer-states.go b/node/peer-states.go index 7a1de54..35ebc0b 100644 --- a/node/peer-states.go +++ b/node/peer-states.go @@ -12,7 +12,14 @@ import ( type peerState interface { Name() string + OnSyn(netip.AddrPort, synPacket) peerState + OnSynAck(netip.AddrPort, synAckPacket) peerState + OnAck(netip.AddrPort, ackPacket) peerState + + // When the peer is updated, we reset. Handled by base state. OnPeerUpdate(*m.Peer) peerState + + // To determe up / dataCipher. Handled by base state. OnPing(netip.AddrPort, pingPacket) peerState OnPong(netip.AddrPort, pongPacket) peerState OnPingTimer() peerState @@ -24,6 +31,7 @@ type peerState interface { type stateBase struct { // The purpose of this state machine is to manage this published data. published *atomic.Pointer[peerRoutingData] + staged peerRoutingData // Local copy of shared data. See publish(). // The other remote peers. peers *remotePeers @@ -39,9 +47,8 @@ type stateBase struct { counter *uint64 // Mutable peer data. - peer *m.Peer - remotePub bool - routingData peerRoutingData // Local copy of shared data. See publish(). + peer *m.Peer + remotePub bool // Timers pingTimer *time.Timer @@ -69,19 +76,19 @@ func (s *stateBase) OnPeerUpdate(peer *m.Peer) peerState { func (s *stateBase) selectStateFromPeer(peer *m.Peer) peerState { s.peer = peer - s.routingData = peerRoutingData{} + s.staged = peerRoutingData{} + defer s.publish() if peer == nil { return newStateNoPeer(s) } - - s.routingData.controlCipher = newControlCipher(s.privKey, peer.EncPubKey) + s.staged.controlCipher = newControlCipher(s.privKey, peer.EncPubKey) ip, isValid := netip.AddrFromSlice(peer.PublicIP) if isValid { s.remotePub = true - s.routingData.remoteAddr = netip.AddrPortFrom(ip, peer.Port) - s.routingData.relay = peer.Mediator + s.staged.remoteAddr = netip.AddrPortFrom(ip, peer.Port) + s.staged.relay = peer.Mediator if s.localPub && s.localIP < s.remoteIP { return newStateServer(s) @@ -96,10 +103,16 @@ func (s *stateBase) selectStateFromPeer(peer *m.Peer) peerState { return newStateSelectRelay(s) } -func (s *stateBase) OnPing(rAddr netip.AddrPort, p pingPacket) peerState { return nil } -func (s *stateBase) OnPong(rAddr netip.AddrPort, p pongPacket) peerState { return nil } -func (s *stateBase) OnPingTimer() peerState { return nil } -func (s *stateBase) OnTimeoutTimer() peerState { return nil } +func (s *stateBase) OnSyn(rAddr netip.AddrPort, p synPacket) peerState { return nil } +func (s *stateBase) OnSynAck(rAddr netip.AddrPort, p synAckPacket) peerState { return nil } +func (s *stateBase) OnAck(rAddr netip.AddrPort, p ackPacket) peerState { return nil } +func (s *stateBase) OnPing(rAddr netip.AddrPort, p pingPacket) peerState { return nil } +func (s *stateBase) OnPong(rAddr netip.AddrPort, p pongPacket) peerState { return nil } +func (s *stateBase) OnPingTimer() peerState { return nil } + +func (s *stateBase) OnTimeoutTimer() peerState { + return s.selectStateFromPeer(s.peer) +} // Helpers. @@ -113,7 +126,7 @@ func (s *stateBase) logf(msg string, args ...any) { } func (s *stateBase) publish() { - data := s.routingData + data := s.staged s.published.Store(&data) } @@ -148,11 +161,11 @@ func (s *stateBase) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { DestIP: s.remoteIP, } - buf = s.routingData.controlCipher.Encrypt(h, buf, s.encBuf) - if s.routingData.relayIP != 0 { - s.peers[s.routingData.relayIP].RelayFor(s.remoteIP, buf) + buf = s.staged.controlCipher.Encrypt(h, buf, s.encBuf) + if s.staged.relayIP != 0 { + s.peers[s.staged.relayIP].RelayFor(s.remoteIP, buf) } else { - s.conn.WriteTo(buf, s.routingData.remoteAddr) + s.conn.WriteTo(buf, s.staged.remoteAddr) } } @@ -162,6 +175,8 @@ type stateNoPeer struct{ *stateBase } func newStateNoPeer(b *stateBase) *stateNoPeer { s := &stateNoPeer{b} + s.pingTimer.Stop() + s.timeoutTimer.Stop() s.publish() return s } @@ -177,8 +192,8 @@ func newStateClient(b *stateBase) peerState { s := &stateClient{stateBase: b} s.publish() - s.routingData.dataCipher = newDataCipher() - s.sharedKey = s.routingData.dataCipher.Key() + s.staged.dataCipher = newDataCipher() + s.sharedKey = s.staged.dataCipher.Key() s.sendPing(s.sharedKey) s.resetPingTimer() @@ -189,8 +204,8 @@ func newStateClient(b *stateBase) peerState { func (s *stateClient) Name() string { return "client" } func (s *stateClient) OnPong(addr netip.AddrPort, p pongPacket) peerState { - if !s.routingData.up { - s.routingData.up = true + if !s.staged.up { + s.staged.up = true s.publish() } s.resetTimeoutTimer() @@ -204,7 +219,7 @@ func (s *stateClient) OnPingTimer() peerState { } func (s *stateClient) OnTimeoutTimer() peerState { - s.routingData.up = false + s.staged.up = false s.publish() return nil } @@ -226,17 +241,17 @@ func newStateServer(b *stateBase) peerState { func (s *stateServer) Name() string { return "server" } func (s *stateServer) OnPing(addr netip.AddrPort, p pingPacket) peerState { - if addr != s.routingData.remoteAddr { + if addr != s.staged.remoteAddr { s.logf("Got new peer address: %v", addr) - s.routingData.remoteAddr = addr - s.routingData.up = true + s.staged.remoteAddr = addr + s.staged.up = true s.publish() } - if s.routingData.dataCipher == nil || p.SharedKey != s.routingData.dataCipher.Key() { + if s.staged.dataCipher == nil || p.SharedKey != s.staged.dataCipher.Key() { s.logf("Got new shared key.") - s.routingData.dataCipher = newDataCipherFromKey(p.SharedKey) - s.routingData.up = true + s.staged.dataCipher = newDataCipherFromKey(p.SharedKey) + s.staged.up = true s.publish() } @@ -252,13 +267,13 @@ type stateSelectRelay struct { func newStateSelectRelay(b *stateBase) peerState { s := &stateSelectRelay{stateBase: b} - s.routingData.dataCipher = nil - s.routingData.up = false + s.staged.dataCipher = nil + s.staged.up = false s.publish() if relay := s.selectRelay(); relay != 0 { - s.routingData.up = false - s.routingData.relayIP = relay + s.staged.up = false + s.staged.relayIP = relay return s.selectRole() } @@ -278,7 +293,8 @@ func (s *stateSelectRelay) Name() string { return "select-relay" } func (s *stateSelectRelay) OnPingTimer() peerState { if relay := s.selectRelay(); relay != 0 { - s.routingData.relayIP = relay + s.logf("Got relay IP: %d", relay) + s.staged.relayIP = relay return s.selectRole() } s.resetPingTimer() @@ -295,8 +311,8 @@ type stateClientRelayed struct { func newStateClientRelayed(b *stateBase) peerState { s := &stateClientRelayed{stateBase: b} - s.routingData.dataCipher = newDataCipher() - s.sharedKey = s.routingData.dataCipher.Key() + s.staged.dataCipher = newDataCipher() + s.sharedKey = s.staged.dataCipher.Key() s.publish() s.sendPing(s.sharedKey) @@ -308,10 +324,11 @@ func newStateClientRelayed(b *stateBase) peerState { func (s *stateClientRelayed) Name() string { return "client-relayed" } func (s *stateClientRelayed) OnPong(addr netip.AddrPort, p pongPacket) peerState { - if !s.routingData.up { - s.routingData.up = true + if !s.staged.up { + s.staged.up = true s.publish() } + s.resetTimeoutTimer() return nil } @@ -342,10 +359,10 @@ func newStateServerRelayed(b *stateBase) peerState { func (s *stateServerRelayed) Name() string { return "server-relayed" } func (s *stateServerRelayed) OnPing(addr netip.AddrPort, p pingPacket) peerState { - if s.routingData.dataCipher == nil || p.SharedKey != s.routingData.dataCipher.Key() { + if s.staged.dataCipher == nil || p.SharedKey != s.staged.dataCipher.Key() { s.logf("Got new shared key.") - s.routingData.up = true - s.routingData.dataCipher = newDataCipherFromKey(p.SharedKey) + s.staged.up = true + s.staged.dataCipher = newDataCipherFromKey(p.SharedKey) s.publish() } diff --git a/node/peer-supervisor.go b/node/peer-supervisor.go index ac2508e..08691aa 100644 --- a/node/peer-supervisor.go +++ b/node/peer-supervisor.go @@ -46,6 +46,12 @@ func (rp *remotePeer) supervise(conf m.PeerConfig) { case pkt := <-rp.controlPackets: switch p := pkt.Payload.(type) { + case synPacket: + nextState = curState.OnSyn(pkt.RemoteAddr, p) + case synAckPacket: + nextState = curState.OnSynAck(pkt.RemoteAddr, p) + case ackPacket: + nextState = curState.OnAck(pkt.RemoteAddr, p) case pingPacket: nextState = curState.OnPing(pkt.RemoteAddr, p) case pongPacket: