wip: working, modifying logic to allow local discovery and hole punching in the future.
This commit is contained in:
		| @@ -2,11 +2,8 @@ | |||||||
|  |  | ||||||
| ## Roadmap | ## Roadmap | ||||||
|  |  | ||||||
|  | * Use probe and relayed-probe packets vs ping/pong. | ||||||
| * Rename Mediator -> Relay | * Rename Mediator -> Relay | ||||||
| * Node: use symmetric encryption after handshake |  | ||||||
| * AEAD-AES uses a 12 byte nonce. We need to shrink the header: |  | ||||||
|   * Remove Forward and replace it with a HeaderFlags bitfield. |  | ||||||
|     * Forward, Asym/Sym, ... |  | ||||||
| * Use default port 456 | * Use default port 456 | ||||||
| * Remove signing key from hub | * Remove signing key from hub | ||||||
| * Peer: UDP hole-punching | * Peer: UDP hole-punching | ||||||
|   | |||||||
| @@ -112,7 +112,7 @@ func main(netName, listenIP string, port uint16) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	go newHubPoller(netName, conf, peers).Run() | 	go newHubPoller(netName, conf, peers).Run() | ||||||
| 	go readFromConn(conf.PeerIP, conn, peers) | 	go readFromConn(conn, peers) | ||||||
| 	readFromIFace(iface, peers) | 	readFromIFace(iface, peers) | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -130,7 +130,7 @@ func determinePort(confPort, portFromCommandLine uint16) uint16 { | |||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| func readFromConn(localIP byte, conn *net.UDPConn, peers remotePeers) { | func readFromConn(conn *net.UDPConn, peers remotePeers) { | ||||||
|  |  | ||||||
| 	defer panicHandler() | 	defer panicHandler() | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										165
									
								
								node/packets-util.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										165
									
								
								node/packets-util.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,165 @@ | |||||||
|  | package node | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 	"sync/atomic" | ||||||
|  | 	"unsafe" | ||||||
|  | 	"vppn/fasttime" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	traceIDCounter uint64 | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func newTraceID() uint64 { | ||||||
|  | 	return uint64(fasttime.Now()<<30) + atomic.AddUint64(&traceIDCounter, 1) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type binWriter struct { | ||||||
|  | 	b []byte | ||||||
|  | 	i int | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newBinWriter(buf []byte) *binWriter { | ||||||
|  | 	buf = buf[:cap(buf)] | ||||||
|  | 	return &binWriter{buf, 0} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *binWriter) Bool(b bool) *binWriter { | ||||||
|  | 	if b { | ||||||
|  | 		return w.Byte(1) | ||||||
|  | 	} | ||||||
|  | 	return w.Byte(0) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *binWriter) Byte(b byte) *binWriter { | ||||||
|  | 	w.b[w.i] = b | ||||||
|  | 	w.i++ | ||||||
|  | 	return w | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *binWriter) SharedKey(key [32]byte) *binWriter { | ||||||
|  | 	copy(w.b[w.i:w.i+32], key[:]) | ||||||
|  | 	w.i += 32 | ||||||
|  | 	return w | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *binWriter) Uint16(x uint16) *binWriter { | ||||||
|  | 	*(*uint16)(unsafe.Pointer(&w.b[w.i])) = x | ||||||
|  | 	w.i += 2 | ||||||
|  | 	return w | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *binWriter) Uint64(x uint64) *binWriter { | ||||||
|  | 	*(*uint64)(unsafe.Pointer(&w.b[w.i])) = x | ||||||
|  | 	w.i += 8 | ||||||
|  | 	return w | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *binWriter) Int64(x int64) *binWriter { | ||||||
|  | 	*(*int64)(unsafe.Pointer(&w.b[w.i])) = x | ||||||
|  | 	w.i += 8 | ||||||
|  | 	return w | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter { | ||||||
|  | 	addr := addrPort.Addr().As16() | ||||||
|  | 	copy(w.b[w.i:w.i+16], addr[:]) | ||||||
|  | 	w.i += 16 | ||||||
|  | 	return w.Uint16(addrPort.Port()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *binWriter) Build() []byte { | ||||||
|  | 	return w.b[:w.i] | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type binReader struct { | ||||||
|  | 	b   []byte | ||||||
|  | 	i   int | ||||||
|  | 	err error | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newBinReader(buf []byte) *binReader { | ||||||
|  | 	return &binReader{b: buf} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) hasBytes(n int) bool { | ||||||
|  | 	if r.err != nil || (len(r.b)-r.i) < n { | ||||||
|  | 		r.err = errMalformedPacket | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	return true | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) Bool(b *bool) *binReader { | ||||||
|  | 	var bb byte | ||||||
|  | 	r.Byte(&bb) | ||||||
|  | 	*b = bb != 0 | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) Byte(b *byte) *binReader { | ||||||
|  | 	if !r.hasBytes(1) { | ||||||
|  | 		return r | ||||||
|  | 	} | ||||||
|  | 	*b = r.b[r.i] | ||||||
|  | 	r.i++ | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) SharedKey(x *[32]byte) *binReader { | ||||||
|  | 	if !r.hasBytes(32) { | ||||||
|  | 		return r | ||||||
|  | 	} | ||||||
|  | 	*x = ([32]byte)(r.b[r.i : r.i+32]) | ||||||
|  | 	r.i += 32 | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) Uint16(x *uint16) *binReader { | ||||||
|  | 	if !r.hasBytes(2) { | ||||||
|  | 		return r | ||||||
|  | 	} | ||||||
|  | 	*x = *(*uint16)(unsafe.Pointer(&r.b[r.i])) | ||||||
|  | 	r.i += 2 | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) Uint64(x *uint64) *binReader { | ||||||
|  | 	if !r.hasBytes(8) { | ||||||
|  | 		return r | ||||||
|  | 	} | ||||||
|  | 	*x = *(*uint64)(unsafe.Pointer(&r.b[r.i])) | ||||||
|  | 	r.i += 8 | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) Int64(x *int64) *binReader { | ||||||
|  | 	if !r.hasBytes(8) { | ||||||
|  | 		return r | ||||||
|  | 	} | ||||||
|  | 	*x = *(*int64)(unsafe.Pointer(&r.b[r.i])) | ||||||
|  | 	r.i += 8 | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) AddrPort(x *netip.AddrPort) *binReader { | ||||||
|  | 	if !r.hasBytes(18) { | ||||||
|  | 		return r | ||||||
|  | 	} | ||||||
|  | 	addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])) | ||||||
|  | 	addr = addr.Unmap() | ||||||
|  | 	r.i += 16 | ||||||
|  | 	var port uint16 | ||||||
|  | 	r.Uint16(&port) | ||||||
|  | 	*x = netip.AddrPortFrom(addr, port) | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) Error() error { | ||||||
|  | 	return r.err | ||||||
|  | } | ||||||
							
								
								
									
										40
									
								
								node/packets-util_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								node/packets-util_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,40 @@ | |||||||
|  | package node | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestBinWriteRead(t *testing.T) { | ||||||
|  | 	buf := make([]byte, 1024) | ||||||
|  |  | ||||||
|  | 	type Item struct { | ||||||
|  | 		Type     byte | ||||||
|  | 		TraceID  uint64 | ||||||
|  | 		DestAddr netip.AddrPort | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	in := Item{1, 2, netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22)} | ||||||
|  |  | ||||||
|  | 	buf = newBinWriter(buf). | ||||||
|  | 		Byte(in.Type). | ||||||
|  | 		Uint64(in.TraceID). | ||||||
|  | 		AddrPort(in.DestAddr). | ||||||
|  | 		Build() | ||||||
|  |  | ||||||
|  | 	out := Item{} | ||||||
|  |  | ||||||
|  | 	err := newBinReader(buf). | ||||||
|  | 		Byte(&out.Type). | ||||||
|  | 		Uint64(&out.TraceID). | ||||||
|  | 		AddrPort(&out.DestAddr). | ||||||
|  | 		Error() | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !reflect.DeepEqual(in, out) { | ||||||
|  | 		t.Fatal(in, out) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										100
									
								
								node/packets.go
									
									
									
									
									
								
							
							
						
						
									
										100
									
								
								node/packets.go
									
									
									
									
									
								
							| @@ -13,8 +13,12 @@ var ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	packetTypePing = iota + 1 | 	packetTypeSyn = iota + 1 | ||||||
|  | 	packetTypeSynAck | ||||||
|  | 	packetTypeAck | ||||||
|  | 	packetTypePing | ||||||
| 	packetTypePong | 	packetTypePong | ||||||
|  | 	packetTypeRelayed | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
| @@ -31,6 +35,8 @@ func (p *controlPacket) ParsePayload(buf []byte) (err error) { | |||||||
| 		p.Payload, err = parsePingPacket(buf) | 		p.Payload, err = parsePingPacket(buf) | ||||||
| 	case packetTypePong: | 	case packetTypePong: | ||||||
| 		p.Payload, err = parsePongPacket(buf) | 		p.Payload, err = parsePongPacket(buf) | ||||||
|  | 	case packetTypeSyn: | ||||||
|  | 		p.Payload, err = parseSynPacket(buf) | ||||||
| 	default: | 	default: | ||||||
| 		return errUnknownPacketType | 		return errUnknownPacketType | ||||||
| 	} | 	} | ||||||
| @@ -39,34 +45,102 @@ func (p *controlPacket) ParsePayload(buf []byte) (err error) { | |||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type synPacket struct { | ||||||
|  | 	TraceID    uint64         // TraceID to match response w/ request. | ||||||
|  | 	SharedKey  [32]byte       // Our shared key. | ||||||
|  | 	ServerAddr netip.AddrPort // The address we're sending to. | ||||||
|  | 	Direct     bool           // True if this is request isn't relayed. | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p synPacket) Marshal(buf []byte) []byte { | ||||||
|  | 	return newBinWriter(buf). | ||||||
|  | 		Byte(packetTypeSyn). | ||||||
|  | 		Uint64(p.TraceID). | ||||||
|  | 		SharedKey(p.SharedKey). | ||||||
|  | 		AddrPort(p.ServerAddr). | ||||||
|  | 		Bool(p.Direct). | ||||||
|  | 		Build() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func parseSynPacket(buf []byte) (p synPacket, err error) { | ||||||
|  | 	err = newBinReader(buf[1:]). | ||||||
|  | 		Uint64(&p.TraceID). | ||||||
|  | 		SharedKey(&p.SharedKey). | ||||||
|  | 		AddrPort(&p.ServerAddr). | ||||||
|  | 		Bool(&p.Direct). | ||||||
|  | 		Error() | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type synAckPacket struct { | ||||||
|  | 	TraceID uint64 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p synAckPacket) Marshal(buf []byte) []byte { | ||||||
|  | 	return newBinWriter(buf). | ||||||
|  | 		Byte(packetTypeSynAck). | ||||||
|  | 		Uint64(p.TraceID). | ||||||
|  | 		Build() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func parseSynAckPacket(buf []byte) (p synAckPacket, err error) { | ||||||
|  | 	err = newBinReader(buf[1:]). | ||||||
|  | 		Uint64(&p.TraceID). | ||||||
|  | 		Error() | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type ackPacket struct { | ||||||
|  | 	TraceID uint64 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p ackPacket) Marshal(buf []byte) []byte { | ||||||
|  | 	return newBinWriter(buf). | ||||||
|  | 		Byte(packetTypeSynAck). | ||||||
|  | 		Uint64(p.TraceID). | ||||||
|  | 		Build() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func parseAckPacket(buf []byte) (p ackPacket, err error) { | ||||||
|  | 	err = newBinReader(buf[1:]). | ||||||
|  | 		Uint64(&p.TraceID). | ||||||
|  | 		Error() | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| // A pingPacket is sent from a node acting as a client, to a node acting | // A pingPacket is sent from a node acting as a client, to a node acting | ||||||
| // as a server. It always contains the shared key the client is expecting | // as a server. It always contains the shared key the client is expecting | ||||||
| // to use for data encryption with the server. | // to use for data encryption with the server. | ||||||
| type pingPacket struct { | type pingPacket struct { | ||||||
| 	SentAt    int64 // UnixMilli. | 	SentAt    int64 // UnixMilli. // Not used. Use traceID. | ||||||
| 	SharedKey [32]byte | 	SharedKey [32]byte | ||||||
| } | } | ||||||
|  |  | ||||||
| func newPingPacket(sharedKey [32]byte) (pp pingPacket) { | func newPingPacket(sharedKey [32]byte) (pp pingPacket) { | ||||||
| 	pp.SentAt = time.Now().UnixMilli() | 	pp.SentAt = time.Now().UnixMilli() | ||||||
| 	copy(pp.SharedKey[:], sharedKey[:]) | 	pp.SharedKey = sharedKey | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p pingPacket) Marshal(buf []byte) []byte { | func (p pingPacket) Marshal(buf []byte) []byte { | ||||||
| 	buf = buf[:41] | 	return newBinWriter(buf). | ||||||
| 	buf[0] = packetTypePing | 		Byte(packetTypePing). | ||||||
| 	*(*uint64)(unsafe.Pointer(&buf[1])) = uint64(p.SentAt) | 		Int64(p.SentAt). | ||||||
| 	copy(buf[9:41], p.SharedKey[:]) | 		SharedKey(p.SharedKey). | ||||||
| 	return buf | 		Build() | ||||||
| } | } | ||||||
|  |  | ||||||
| func parsePingPacket(buf []byte) (p pingPacket, err error) { | func parsePingPacket(buf []byte) (p pingPacket, err error) { | ||||||
| 	if len(buf) != 41 { | 	err = newBinReader(buf[1:]). | ||||||
| 		return p, errMalformedPacket | 		Int64(&p.SentAt). | ||||||
| 	} | 		SharedKey(&p.SharedKey). | ||||||
| 	p.SentAt = *(*int64)(unsafe.Pointer(&buf[1])) | 		Error() | ||||||
| 	copy(p.SharedKey[:], buf[9:41]) |  | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2,10 +2,59 @@ package node | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"crypto/rand" | 	"crypto/rand" | ||||||
|  | 	"net/netip" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"testing" | 	"testing" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | func TestPacketSyn(t *testing.T) { | ||||||
|  | 	in := synPacket{ | ||||||
|  | 		TraceID:    newTraceID(), | ||||||
|  | 		Direct:     true, | ||||||
|  | 		ServerAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 34), | ||||||
|  | 	} | ||||||
|  | 	rand.Read(in.SharedKey[:]) | ||||||
|  |  | ||||||
|  | 	out, err := parseSynPacket(in.Marshal(make([]byte, bufferSize))) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !reflect.DeepEqual(in, out) { | ||||||
|  | 		t.Fatal("\n", in, "\n", out) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestPacketSynAck(t *testing.T) { | ||||||
|  | 	in := synAckPacket{ | ||||||
|  | 		TraceID: newTraceID(), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	out, err := parseSynAckPacket(in.Marshal(make([]byte, bufferSize))) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !reflect.DeepEqual(in, out) { | ||||||
|  | 		t.Fatal("\n", in, "\n", out) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestPacketAck(t *testing.T) { | ||||||
|  | 	in := ackPacket{ | ||||||
|  | 		TraceID: newTraceID(), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	out, err := parseAckPacket(in.Marshal(make([]byte, bufferSize))) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !reflect.DeepEqual(in, out) { | ||||||
|  | 		t.Fatal("\n", in, "\n", out) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestPacketPing(t *testing.T) { | func TestPacketPing(t *testing.T) { | ||||||
| 	sharedKey := make([]byte, 32) | 	sharedKey := make([]byte, 32) | ||||||
| 	rand.Read(sharedKey) | 	rand.Read(sharedKey) | ||||||
|   | |||||||
| @@ -12,7 +12,14 @@ import ( | |||||||
|  |  | ||||||
| type peerState interface { | type peerState interface { | ||||||
| 	Name() string | 	Name() string | ||||||
|  | 	OnSyn(netip.AddrPort, synPacket) peerState | ||||||
|  | 	OnSynAck(netip.AddrPort, synAckPacket) peerState | ||||||
|  | 	OnAck(netip.AddrPort, ackPacket) peerState | ||||||
|  |  | ||||||
|  | 	// When the peer is updated, we reset. Handled by base state. | ||||||
| 	OnPeerUpdate(*m.Peer) peerState | 	OnPeerUpdate(*m.Peer) peerState | ||||||
|  |  | ||||||
|  | 	// To determe up / dataCipher. Handled by base state. | ||||||
| 	OnPing(netip.AddrPort, pingPacket) peerState | 	OnPing(netip.AddrPort, pingPacket) peerState | ||||||
| 	OnPong(netip.AddrPort, pongPacket) peerState | 	OnPong(netip.AddrPort, pongPacket) peerState | ||||||
| 	OnPingTimer() peerState | 	OnPingTimer() peerState | ||||||
| @@ -24,6 +31,7 @@ type peerState interface { | |||||||
| type stateBase struct { | type stateBase struct { | ||||||
| 	// The purpose of this state machine is to manage this published data. | 	// The purpose of this state machine is to manage this published data. | ||||||
| 	published *atomic.Pointer[peerRoutingData] | 	published *atomic.Pointer[peerRoutingData] | ||||||
|  | 	staged    peerRoutingData // Local copy of shared data. See publish(). | ||||||
|  |  | ||||||
| 	// The other remote peers. | 	// The other remote peers. | ||||||
| 	peers *remotePeers | 	peers *remotePeers | ||||||
| @@ -41,7 +49,6 @@ type stateBase struct { | |||||||
| 	// Mutable peer data. | 	// Mutable peer data. | ||||||
| 	peer      *m.Peer | 	peer      *m.Peer | ||||||
| 	remotePub bool | 	remotePub bool | ||||||
| 	routingData peerRoutingData // Local copy of shared data. See publish(). |  | ||||||
|  |  | ||||||
| 	// Timers | 	// Timers | ||||||
| 	pingTimer    *time.Timer | 	pingTimer    *time.Timer | ||||||
| @@ -69,19 +76,19 @@ func (s *stateBase) OnPeerUpdate(peer *m.Peer) peerState { | |||||||
|  |  | ||||||
| func (s *stateBase) selectStateFromPeer(peer *m.Peer) peerState { | func (s *stateBase) selectStateFromPeer(peer *m.Peer) peerState { | ||||||
| 	s.peer = peer | 	s.peer = peer | ||||||
| 	s.routingData = peerRoutingData{} | 	s.staged = peerRoutingData{} | ||||||
|  | 	defer s.publish() | ||||||
|  |  | ||||||
| 	if peer == nil { | 	if peer == nil { | ||||||
| 		return newStateNoPeer(s) | 		return newStateNoPeer(s) | ||||||
| 	} | 	} | ||||||
|  | 	s.staged.controlCipher = newControlCipher(s.privKey, peer.EncPubKey) | ||||||
| 	s.routingData.controlCipher = newControlCipher(s.privKey, peer.EncPubKey) |  | ||||||
|  |  | ||||||
| 	ip, isValid := netip.AddrFromSlice(peer.PublicIP) | 	ip, isValid := netip.AddrFromSlice(peer.PublicIP) | ||||||
| 	if isValid { | 	if isValid { | ||||||
| 		s.remotePub = true | 		s.remotePub = true | ||||||
| 		s.routingData.remoteAddr = netip.AddrPortFrom(ip, peer.Port) | 		s.staged.remoteAddr = netip.AddrPortFrom(ip, peer.Port) | ||||||
| 		s.routingData.relay = peer.Mediator | 		s.staged.relay = peer.Mediator | ||||||
|  |  | ||||||
| 		if s.localPub && s.localIP < s.remoteIP { | 		if s.localPub && s.localIP < s.remoteIP { | ||||||
| 			return newStateServer(s) | 			return newStateServer(s) | ||||||
| @@ -96,10 +103,16 @@ func (s *stateBase) selectStateFromPeer(peer *m.Peer) peerState { | |||||||
| 	return newStateSelectRelay(s) | 	return newStateSelectRelay(s) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (s *stateBase) OnSyn(rAddr netip.AddrPort, p synPacket) peerState       { return nil } | ||||||
|  | func (s *stateBase) OnSynAck(rAddr netip.AddrPort, p synAckPacket) peerState { return nil } | ||||||
|  | func (s *stateBase) OnAck(rAddr netip.AddrPort, p ackPacket) peerState       { return nil } | ||||||
| func (s *stateBase) OnPing(rAddr netip.AddrPort, p pingPacket) peerState     { return nil } | func (s *stateBase) OnPing(rAddr netip.AddrPort, p pingPacket) peerState     { return nil } | ||||||
| func (s *stateBase) OnPong(rAddr netip.AddrPort, p pongPacket) peerState     { return nil } | func (s *stateBase) OnPong(rAddr netip.AddrPort, p pongPacket) peerState     { return nil } | ||||||
| func (s *stateBase) OnPingTimer() peerState                                  { return nil } | func (s *stateBase) OnPingTimer() peerState                                  { return nil } | ||||||
| func (s *stateBase) OnTimeoutTimer() peerState                           { return nil } |  | ||||||
|  | func (s *stateBase) OnTimeoutTimer() peerState { | ||||||
|  | 	return s.selectStateFromPeer(s.peer) | ||||||
|  | } | ||||||
|  |  | ||||||
| // Helpers. | // Helpers. | ||||||
|  |  | ||||||
| @@ -113,7 +126,7 @@ func (s *stateBase) logf(msg string, args ...any) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (s *stateBase) publish() { | func (s *stateBase) publish() { | ||||||
| 	data := s.routingData | 	data := s.staged | ||||||
| 	s.published.Store(&data) | 	s.published.Store(&data) | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -148,11 +161,11 @@ func (s *stateBase) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { | |||||||
| 		DestIP:   s.remoteIP, | 		DestIP:   s.remoteIP, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	buf = s.routingData.controlCipher.Encrypt(h, buf, s.encBuf) | 	buf = s.staged.controlCipher.Encrypt(h, buf, s.encBuf) | ||||||
| 	if s.routingData.relayIP != 0 { | 	if s.staged.relayIP != 0 { | ||||||
| 		s.peers[s.routingData.relayIP].RelayFor(s.remoteIP, buf) | 		s.peers[s.staged.relayIP].RelayFor(s.remoteIP, buf) | ||||||
| 	} else { | 	} else { | ||||||
| 		s.conn.WriteTo(buf, s.routingData.remoteAddr) | 		s.conn.WriteTo(buf, s.staged.remoteAddr) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -162,6 +175,8 @@ type stateNoPeer struct{ *stateBase } | |||||||
|  |  | ||||||
| func newStateNoPeer(b *stateBase) *stateNoPeer { | func newStateNoPeer(b *stateBase) *stateNoPeer { | ||||||
| 	s := &stateNoPeer{b} | 	s := &stateNoPeer{b} | ||||||
|  | 	s.pingTimer.Stop() | ||||||
|  | 	s.timeoutTimer.Stop() | ||||||
| 	s.publish() | 	s.publish() | ||||||
| 	return s | 	return s | ||||||
| } | } | ||||||
| @@ -177,8 +192,8 @@ func newStateClient(b *stateBase) peerState { | |||||||
| 	s := &stateClient{stateBase: b} | 	s := &stateClient{stateBase: b} | ||||||
| 	s.publish() | 	s.publish() | ||||||
|  |  | ||||||
| 	s.routingData.dataCipher = newDataCipher() | 	s.staged.dataCipher = newDataCipher() | ||||||
| 	s.sharedKey = s.routingData.dataCipher.Key() | 	s.sharedKey = s.staged.dataCipher.Key() | ||||||
|  |  | ||||||
| 	s.sendPing(s.sharedKey) | 	s.sendPing(s.sharedKey) | ||||||
| 	s.resetPingTimer() | 	s.resetPingTimer() | ||||||
| @@ -189,8 +204,8 @@ func newStateClient(b *stateBase) peerState { | |||||||
| func (s *stateClient) Name() string { return "client" } | func (s *stateClient) Name() string { return "client" } | ||||||
|  |  | ||||||
| func (s *stateClient) OnPong(addr netip.AddrPort, p pongPacket) peerState { | func (s *stateClient) OnPong(addr netip.AddrPort, p pongPacket) peerState { | ||||||
| 	if !s.routingData.up { | 	if !s.staged.up { | ||||||
| 		s.routingData.up = true | 		s.staged.up = true | ||||||
| 		s.publish() | 		s.publish() | ||||||
| 	} | 	} | ||||||
| 	s.resetTimeoutTimer() | 	s.resetTimeoutTimer() | ||||||
| @@ -204,7 +219,7 @@ func (s *stateClient) OnPingTimer() peerState { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (s *stateClient) OnTimeoutTimer() peerState { | func (s *stateClient) OnTimeoutTimer() peerState { | ||||||
| 	s.routingData.up = false | 	s.staged.up = false | ||||||
| 	s.publish() | 	s.publish() | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| @@ -226,17 +241,17 @@ func newStateServer(b *stateBase) peerState { | |||||||
| func (s *stateServer) Name() string { return "server" } | func (s *stateServer) Name() string { return "server" } | ||||||
|  |  | ||||||
| func (s *stateServer) OnPing(addr netip.AddrPort, p pingPacket) peerState { | func (s *stateServer) OnPing(addr netip.AddrPort, p pingPacket) peerState { | ||||||
| 	if addr != s.routingData.remoteAddr { | 	if addr != s.staged.remoteAddr { | ||||||
| 		s.logf("Got new peer address: %v", addr) | 		s.logf("Got new peer address: %v", addr) | ||||||
| 		s.routingData.remoteAddr = addr | 		s.staged.remoteAddr = addr | ||||||
| 		s.routingData.up = true | 		s.staged.up = true | ||||||
| 		s.publish() | 		s.publish() | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if s.routingData.dataCipher == nil || p.SharedKey != s.routingData.dataCipher.Key() { | 	if s.staged.dataCipher == nil || p.SharedKey != s.staged.dataCipher.Key() { | ||||||
| 		s.logf("Got new shared key.") | 		s.logf("Got new shared key.") | ||||||
| 		s.routingData.dataCipher = newDataCipherFromKey(p.SharedKey) | 		s.staged.dataCipher = newDataCipherFromKey(p.SharedKey) | ||||||
| 		s.routingData.up = true | 		s.staged.up = true | ||||||
| 		s.publish() | 		s.publish() | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -252,13 +267,13 @@ type stateSelectRelay struct { | |||||||
|  |  | ||||||
| func newStateSelectRelay(b *stateBase) peerState { | func newStateSelectRelay(b *stateBase) peerState { | ||||||
| 	s := &stateSelectRelay{stateBase: b} | 	s := &stateSelectRelay{stateBase: b} | ||||||
| 	s.routingData.dataCipher = nil | 	s.staged.dataCipher = nil | ||||||
| 	s.routingData.up = false | 	s.staged.up = false | ||||||
| 	s.publish() | 	s.publish() | ||||||
|  |  | ||||||
| 	if relay := s.selectRelay(); relay != 0 { | 	if relay := s.selectRelay(); relay != 0 { | ||||||
| 		s.routingData.up = false | 		s.staged.up = false | ||||||
| 		s.routingData.relayIP = relay | 		s.staged.relayIP = relay | ||||||
| 		return s.selectRole() | 		return s.selectRole() | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -278,7 +293,8 @@ func (s *stateSelectRelay) Name() string { return "select-relay" } | |||||||
|  |  | ||||||
| func (s *stateSelectRelay) OnPingTimer() peerState { | func (s *stateSelectRelay) OnPingTimer() peerState { | ||||||
| 	if relay := s.selectRelay(); relay != 0 { | 	if relay := s.selectRelay(); relay != 0 { | ||||||
| 		s.routingData.relayIP = relay | 		s.logf("Got relay IP: %d", relay) | ||||||
|  | 		s.staged.relayIP = relay | ||||||
| 		return s.selectRole() | 		return s.selectRole() | ||||||
| 	} | 	} | ||||||
| 	s.resetPingTimer() | 	s.resetPingTimer() | ||||||
| @@ -295,8 +311,8 @@ type stateClientRelayed struct { | |||||||
| func newStateClientRelayed(b *stateBase) peerState { | func newStateClientRelayed(b *stateBase) peerState { | ||||||
| 	s := &stateClientRelayed{stateBase: b} | 	s := &stateClientRelayed{stateBase: b} | ||||||
|  |  | ||||||
| 	s.routingData.dataCipher = newDataCipher() | 	s.staged.dataCipher = newDataCipher() | ||||||
| 	s.sharedKey = s.routingData.dataCipher.Key() | 	s.sharedKey = s.staged.dataCipher.Key() | ||||||
| 	s.publish() | 	s.publish() | ||||||
|  |  | ||||||
| 	s.sendPing(s.sharedKey) | 	s.sendPing(s.sharedKey) | ||||||
| @@ -308,10 +324,11 @@ func newStateClientRelayed(b *stateBase) peerState { | |||||||
| func (s *stateClientRelayed) Name() string { return "client-relayed" } | func (s *stateClientRelayed) Name() string { return "client-relayed" } | ||||||
|  |  | ||||||
| func (s *stateClientRelayed) OnPong(addr netip.AddrPort, p pongPacket) peerState { | func (s *stateClientRelayed) OnPong(addr netip.AddrPort, p pongPacket) peerState { | ||||||
| 	if !s.routingData.up { | 	if !s.staged.up { | ||||||
| 		s.routingData.up = true | 		s.staged.up = true | ||||||
| 		s.publish() | 		s.publish() | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	s.resetTimeoutTimer() | 	s.resetTimeoutTimer() | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| @@ -342,10 +359,10 @@ func newStateServerRelayed(b *stateBase) peerState { | |||||||
| func (s *stateServerRelayed) Name() string { return "server-relayed" } | func (s *stateServerRelayed) Name() string { return "server-relayed" } | ||||||
|  |  | ||||||
| func (s *stateServerRelayed) OnPing(addr netip.AddrPort, p pingPacket) peerState { | func (s *stateServerRelayed) OnPing(addr netip.AddrPort, p pingPacket) peerState { | ||||||
| 	if s.routingData.dataCipher == nil || p.SharedKey != s.routingData.dataCipher.Key() { | 	if s.staged.dataCipher == nil || p.SharedKey != s.staged.dataCipher.Key() { | ||||||
| 		s.logf("Got new shared key.") | 		s.logf("Got new shared key.") | ||||||
| 		s.routingData.up = true | 		s.staged.up = true | ||||||
| 		s.routingData.dataCipher = newDataCipherFromKey(p.SharedKey) | 		s.staged.dataCipher = newDataCipherFromKey(p.SharedKey) | ||||||
| 		s.publish() | 		s.publish() | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -46,6 +46,12 @@ func (rp *remotePeer) supervise(conf m.PeerConfig) { | |||||||
|  |  | ||||||
| 		case pkt := <-rp.controlPackets: | 		case pkt := <-rp.controlPackets: | ||||||
| 			switch p := pkt.Payload.(type) { | 			switch p := pkt.Payload.(type) { | ||||||
|  | 			case synPacket: | ||||||
|  | 				nextState = curState.OnSyn(pkt.RemoteAddr, p) | ||||||
|  | 			case synAckPacket: | ||||||
|  | 				nextState = curState.OnSynAck(pkt.RemoteAddr, p) | ||||||
|  | 			case ackPacket: | ||||||
|  | 				nextState = curState.OnAck(pkt.RemoteAddr, p) | ||||||
| 			case pingPacket: | 			case pingPacket: | ||||||
| 				nextState = curState.OnPing(pkt.RemoteAddr, p) | 				nextState = curState.OnPing(pkt.RemoteAddr, p) | ||||||
| 			case pongPacket: | 			case pongPacket: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user