sym-encryption #1
| @@ -1,9 +1,15 @@ | ||||
| package node | ||||
|  | ||||
| import "net/netip" | ||||
|  | ||||
| const ( | ||||
| 	bufferSize            = 1536 | ||||
| 	if_mtu                = 1200 | ||||
| 	if_mtu                = 1400 | ||||
| 	if_queue_len          = 2048 | ||||
| 	controlCipherOverhead = 16 | ||||
| 	dataCipherOverhead    = 16 | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	zeroAddrPort = netip.AddrPort{} | ||||
| ) | ||||
|   | ||||
| @@ -3,8 +3,6 @@ package node | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"net/netip" | ||||
| 	"time" | ||||
| 	"unsafe" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| @@ -31,10 +29,6 @@ type controlPacket struct { | ||||
|  | ||||
| func (p *controlPacket) ParsePayload(buf []byte) (err error) { | ||||
| 	switch buf[0] { | ||||
| 	case packetTypePing: | ||||
| 		p.Payload, err = parsePingPacket(buf) | ||||
| 	case packetTypePong: | ||||
| 		p.Payload, err = parsePongPacket(buf) | ||||
| 	case packetTypeSyn: | ||||
| 		p.Payload, err = parseSynPacket(buf) | ||||
| 	case packetTypeSynAck: | ||||
| @@ -50,10 +44,9 @@ func (p *controlPacket) ParsePayload(buf []byte) (err error) { | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type synPacket struct { | ||||
| 	TraceID    uint64         // TraceID to match response w/ request. | ||||
| 	SharedKey  [32]byte       // Our shared key. | ||||
| 	ServerAddr netip.AddrPort // The address we're sending to. | ||||
| 	RelayIP    byte | ||||
| 	TraceID   uint64   // TraceID to match response w/ request. | ||||
| 	SharedKey [32]byte // Our shared key. | ||||
| 	RelayIP   byte | ||||
| } | ||||
|  | ||||
| func (p synPacket) Marshal(buf []byte) []byte { | ||||
| @@ -61,7 +54,6 @@ func (p synPacket) Marshal(buf []byte) []byte { | ||||
| 		Byte(packetTypeSyn). | ||||
| 		Uint64(p.TraceID). | ||||
| 		SharedKey(p.SharedKey). | ||||
| 		AddrPort(p.ServerAddr). | ||||
| 		Byte(p.RelayIP). | ||||
| 		Build() | ||||
| } | ||||
| @@ -70,7 +62,6 @@ func parseSynPacket(buf []byte) (p synPacket, err error) { | ||||
| 	err = newBinReader(buf[1:]). | ||||
| 		Uint64(&p.TraceID). | ||||
| 		SharedKey(&p.SharedKey). | ||||
| 		AddrPort(&p.ServerAddr). | ||||
| 		Byte(&p.RelayIP). | ||||
| 		Error() | ||||
| 	return | ||||
| @@ -119,63 +110,3 @@ func parseAckPacket(buf []byte) (p ackPacket, err error) { | ||||
| 		Error() | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| // A pingPacket is sent from a node acting as a client, to a node acting | ||||
| // as a server. It always contains the shared key the client is expecting | ||||
| // to use for data encryption with the server. | ||||
| type pingPacket struct { | ||||
| 	SentAt int64 // UnixMilli. // Not used. Use traceID. | ||||
| } | ||||
|  | ||||
| func newPingPacket() (pp pingPacket) { | ||||
| 	pp.SentAt = time.Now().UnixMilli() | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (p pingPacket) Marshal(buf []byte) []byte { | ||||
| 	return newBinWriter(buf). | ||||
| 		Byte(packetTypePing). | ||||
| 		Int64(p.SentAt). | ||||
| 		Build() | ||||
| } | ||||
|  | ||||
| func parsePingPacket(buf []byte) (p pingPacket, err error) { | ||||
| 	err = newBinReader(buf[1:]). | ||||
| 		Int64(&p.SentAt). | ||||
| 		Error() | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| // A pongPacket is sent by a node in a server role in response to a pingPacket. | ||||
| type pongPacket struct { | ||||
| 	SentAt  int64 // UnixMilli. | ||||
| 	RecvdAt int64 // UnixMilli. | ||||
| } | ||||
|  | ||||
| func newPongPacket(sentAt int64) (pp pongPacket) { | ||||
| 	pp.SentAt = sentAt | ||||
| 	pp.RecvdAt = time.Now().UnixMilli() | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (p pongPacket) Marshal(buf []byte) []byte { | ||||
| 	buf = buf[:17] | ||||
| 	buf[0] = packetTypePong | ||||
| 	*(*uint64)(unsafe.Pointer(&buf[1])) = uint64(p.SentAt) | ||||
| 	*(*uint64)(unsafe.Pointer(&buf[9])) = uint64(p.RecvdAt) | ||||
|  | ||||
| 	return buf | ||||
| } | ||||
|  | ||||
| func parsePongPacket(buf []byte) (p pongPacket, err error) { | ||||
| 	if len(buf) != 17 { | ||||
| 		return p, errMalformedPacket | ||||
| 	} | ||||
| 	p.SentAt = *(*int64)(unsafe.Pointer(&buf[1])) | ||||
| 	p.RecvdAt = *(*int64)(unsafe.Pointer(&buf[9])) | ||||
| 	return | ||||
| } | ||||
|   | ||||
| @@ -2,16 +2,13 @@ package node | ||||
|  | ||||
| import ( | ||||
| 	"crypto/rand" | ||||
| 	"net/netip" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestPacketSyn(t *testing.T) { | ||||
| 	in := synPacket{ | ||||
| 		TraceID:    newTraceID(), | ||||
| 		Direct:     true, | ||||
| 		ServerAddr: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 34), | ||||
| 		TraceID: newTraceID(), | ||||
| 	} | ||||
| 	rand.Read(in.SharedKey[:]) | ||||
|  | ||||
| @@ -54,38 +51,3 @@ func TestPacketAck(t *testing.T) { | ||||
| 		t.Fatal("\n", in, "\n", out) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestPacketPing(t *testing.T) { | ||||
| 	sharedKey := make([]byte, 32) | ||||
| 	rand.Read(sharedKey) | ||||
|  | ||||
| 	buf := make([]byte, bufferSize) | ||||
|  | ||||
| 	p := newPingPacket([32]byte(sharedKey)) | ||||
| 	out := p.Marshal(buf) | ||||
|  | ||||
| 	p2, err := parsePingPacket(out) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	if !reflect.DeepEqual(p, p2) { | ||||
| 		t.Fatal(p, p2) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestPacketPong(t *testing.T) { | ||||
| 	buf := make([]byte, bufferSize) | ||||
|  | ||||
| 	p := newPongPacket(123566) | ||||
| 	out := p.Marshal(buf) | ||||
|  | ||||
| 	p2, err := parsePongPacket(out) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	if !reflect.DeepEqual(p, p2) { | ||||
| 		t.Fatal(p, p2) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -1,405 +0,0 @@ | ||||
| package node | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"math/rand" | ||||
| 	"net/netip" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
| 	"vppn/m" | ||||
| ) | ||||
|  | ||||
| type peerState interface { | ||||
| 	Name() string | ||||
| 	OnSyn(netip.AddrPort, synPacket) peerState | ||||
| 	OnSynAck(netip.AddrPort, synAckPacket) peerState | ||||
| 	OnAck(netip.AddrPort, ackPacket) peerState | ||||
|  | ||||
| 	OnPingTimer() peerState | ||||
| 	OnTimeoutTimer() peerState | ||||
|  | ||||
| 	// When the peer is updated, we reset. Handled by base state. | ||||
| 	OnPeerUpdate(*m.Peer) peerState | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type stateBase struct { | ||||
| 	// The purpose of this state machine is to manage this published data. | ||||
| 	published *atomic.Pointer[peerRoutingData] | ||||
| 	staged    peerRoutingData // Local copy of shared data. See publish(). | ||||
|  | ||||
| 	// The other remote peers. | ||||
| 	peers *remotePeers | ||||
|  | ||||
| 	// Immutable data. | ||||
| 	localIP  byte | ||||
| 	localPub bool | ||||
| 	remoteIP byte | ||||
| 	privKey  []byte | ||||
| 	conn     *connWriter | ||||
|  | ||||
| 	// For sending to peer. | ||||
| 	counter *uint64 | ||||
|  | ||||
| 	// Mutable peer data. | ||||
| 	peer      *m.Peer | ||||
| 	remotePub bool | ||||
|  | ||||
| 	// Timers | ||||
| 	pingTimer    *time.Timer | ||||
| 	timeoutTimer *time.Timer | ||||
|  | ||||
| 	buf    []byte | ||||
| 	encBuf []byte | ||||
| } | ||||
|  | ||||
| func (sb *stateBase) Name() string { return "idle" } | ||||
|  | ||||
| func (s *stateBase) OnPeerUpdate(peer *m.Peer) peerState { | ||||
| 	// Both nil: no change. | ||||
| 	if peer == nil && s.peer == nil { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	// No change. | ||||
| 	if peer != nil && s.peer != nil && s.peer.Version == peer.Version { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	return s.selectStateFromPeer(peer) | ||||
| } | ||||
|  | ||||
| func (s *stateBase) selectStateFromPeer(peer *m.Peer) peerState { | ||||
| 	s.peer = peer | ||||
| 	s.staged = peerRoutingData{} | ||||
| 	defer s.publish() | ||||
|  | ||||
| 	if peer == nil { | ||||
| 		return newStateNoPeer(s) | ||||
| 	} | ||||
|  | ||||
| 	s.staged.controlCipher = newControlCipher(s.privKey, peer.EncPubKey) | ||||
| 	s.staged.dataCipher = newDataCipher() | ||||
|  | ||||
| 	s.resetPingTimer() | ||||
| 	s.resetTimeoutTimer() | ||||
|  | ||||
| 	ip, isValid := netip.AddrFromSlice(peer.PublicIP) | ||||
| 	if isValid { | ||||
| 		s.remotePub = true | ||||
| 		s.staged.remoteAddr = netip.AddrPortFrom(ip, peer.Port) | ||||
| 		s.staged.relay = peer.Mediator | ||||
| 	} | ||||
|  | ||||
| 	if s.remotePub == s.localPub { | ||||
| 		if s.localIP < s.remoteIP { | ||||
| 			return newStateServer2(s) | ||||
| 		} | ||||
| 		return newStateDialLocal(s) | ||||
| 	} | ||||
|  | ||||
| 	if s.remotePub { | ||||
| 		return newStateDialLocal(s) | ||||
| 	} | ||||
| 	return newStateServer2(s) | ||||
| } | ||||
|  | ||||
| func (s *stateBase) OnSyn(rAddr netip.AddrPort, p synPacket) peerState       { return nil } | ||||
| func (s *stateBase) OnSynAck(rAddr netip.AddrPort, p synAckPacket) peerState { return nil } | ||||
| func (s *stateBase) OnAck(rAddr netip.AddrPort, p ackPacket) peerState       { return nil } | ||||
|  | ||||
| func (s *stateBase) OnPingTimer() peerState    { return nil } | ||||
| func (s *stateBase) OnTimeoutTimer() peerState { return nil } | ||||
|  | ||||
| // Helpers. | ||||
|  | ||||
| func (s *stateBase) resetPingTimer()    { s.pingTimer.Reset(pingInterval) } | ||||
| func (s *stateBase) resetTimeoutTimer() { s.timeoutTimer.Reset(timeoutInterval) } | ||||
| func (s *stateBase) stopPingTimer()     { s.pingTimer.Stop() } | ||||
| func (s *stateBase) stopTimeoutTimer()  { s.timeoutTimer.Stop() } | ||||
|  | ||||
| func (s *stateBase) logf(msg string, args ...any) { | ||||
| 	log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...) | ||||
| } | ||||
|  | ||||
| func (s *stateBase) publish() { | ||||
| 	data := s.staged | ||||
| 	s.published.Store(&data) | ||||
| } | ||||
|  | ||||
| func (s *stateBase) selectRelay() byte { | ||||
| 	possible := make([]byte, 0, 8) | ||||
| 	for i, peer := range s.peers { | ||||
| 		if peer.CanRelay() { | ||||
| 			possible = append(possible, byte(i)) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if len(possible) == 0 { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	return possible[rand.Intn(len(possible))] | ||||
| } | ||||
|  | ||||
| func (s *stateBase) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { | ||||
| 	buf := pkt.Marshal(s.buf) | ||||
| 	h := header{ | ||||
| 		StreamID: controlStreamID, | ||||
| 		Counter:  atomic.AddUint64(s.counter, 1), | ||||
| 		SourceIP: s.localIP, | ||||
| 		DestIP:   s.remoteIP, | ||||
| 	} | ||||
|  | ||||
| 	buf = s.staged.controlCipher.Encrypt(h, buf, s.encBuf) | ||||
| 	if s.staged.relayIP != 0 { | ||||
| 		s.peers[s.staged.relayIP].RelayFor(s.remoteIP, buf) | ||||
| 	} else { | ||||
| 		s.conn.WriteTo(buf, s.staged.remoteAddr) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type stateNoPeer struct{ *stateBase } | ||||
|  | ||||
| func newStateNoPeer(b *stateBase) *stateNoPeer { | ||||
| 	s := &stateNoPeer{b} | ||||
| 	s.pingTimer.Stop() | ||||
| 	s.timeoutTimer.Stop() | ||||
| 	s.publish() | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type stateServer2 struct { | ||||
| 	*stateBase | ||||
| 	syn              synPacket | ||||
| 	publishedTraceID uint64 | ||||
| } | ||||
|  | ||||
| // TODO: Server should send SynAck packets on a loop. | ||||
| func newStateServer2(b *stateBase) peerState { | ||||
| 	s := &stateServer2{stateBase: b} | ||||
| 	s.resetTimeoutTimer() | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *stateServer2) Name() string { return "server" } | ||||
|  | ||||
| func (s *stateServer2) OnSyn(remoteAddr netip.AddrPort, p synPacket) peerState { | ||||
| 	s.syn = p | ||||
| 	s.sendControlPacket(newSynAckPacket(p.TraceID)) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (s *stateServer2) OnAck(remoteAddr netip.AddrPort, p ackPacket) peerState { | ||||
| 	if p.TraceID != s.syn.TraceID { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	s.resetTimeoutTimer() | ||||
|  | ||||
| 	if p.TraceID == s.publishedTraceID { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	// Pubish staged | ||||
| 	s.staged.remoteAddr = remoteAddr | ||||
| 	s.staged.dataCipher = newDataCipherFromKey(s.syn.SharedKey) | ||||
| 	s.staged.relayIP = s.syn.RelayIP | ||||
| 	s.staged.up = true | ||||
| 	s.publish() | ||||
|  | ||||
| 	s.publishedTraceID = p.TraceID | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (s *stateServer) OnTimeoutTimer() peerState { | ||||
| 	// TODO: We're down. | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type stateDialLocal struct { | ||||
| 	*stateBase | ||||
| 	syn synPacket | ||||
| } | ||||
|  | ||||
| func newStateDialLocal(b *stateBase) peerState { | ||||
| 	// s := stateDialLocal{stateBase: b} | ||||
| 	// TODO: check for peer local address. | ||||
| 	return newStateDialDirect(b) | ||||
| } | ||||
|  | ||||
| func (s *stateDialLocal) Name() string { return "dial-local" } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type stateDialDirect struct { | ||||
| 	*stateBase | ||||
| 	syn synPacket | ||||
| } | ||||
|  | ||||
| func newStateDialDirect(b *stateBase) peerState { | ||||
| 	// If we don't have an address, dial via relay. | ||||
| 	if b.staged.remoteAddr == zeroAddrPort { | ||||
| 		return newStateNoPeer(b) | ||||
| 	} | ||||
|  | ||||
| 	s := &stateDialDirect{stateBase: b} | ||||
| 	s.syn = synPacket{ | ||||
| 		TraceID:    newTraceID(), | ||||
| 		SharedKey:  s.staged.dataCipher.Key(), | ||||
| 		ServerAddr: b.staged.remoteAddr, | ||||
| 	} | ||||
|  | ||||
| 	s.sendControlPacket(s.syn) | ||||
| 	s.resetTimeoutTimer() | ||||
|  | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *stateDialDirect) Name() string { return "dial-direct" } | ||||
|  | ||||
| func (s *stateDialDirect) OnSynAck(remoteAddr netip.AddrPort, p synAckPacket) peerState { | ||||
| 	if p.TraceID != s.syn.TraceID { | ||||
| 		// Hmm... | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	s.sendControlPacket(ackPacket{TraceID: s.syn.TraceID}) | ||||
| 	s.logf("GOT SYN-ACK! TODO!") | ||||
| 	// client should continue to respond to synAck packets from server. | ||||
| 	// return newStateClientConnected(s.stateBase, s.syn.TraceID) ... | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (s *stateDialDirect) OnTimeoutTimer() peerState { | ||||
| 	s.logf("Timeout when dialing") | ||||
| 	return newStateDialLocal(s.stateBase) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type stateClient struct { | ||||
| 	sharedKey [32]byte | ||||
| 	*stateBase | ||||
| } | ||||
|  | ||||
| func newStateClient(b *stateBase) peerState { | ||||
| 	s := &stateClient{stateBase: b} | ||||
| 	s.publish() | ||||
|  | ||||
| 	s.staged.dataCipher = newDataCipher() | ||||
| 	s.sharedKey = s.staged.dataCipher.Key() | ||||
|  | ||||
| 	s.sendControlPacket(newPingPacket()) | ||||
| 	s.resetPingTimer() | ||||
| 	s.resetTimeoutTimer() | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *stateClient) Name() string { return "client" } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type stateServer struct { | ||||
| 	*stateBase | ||||
| } | ||||
|  | ||||
| func newStateServer(b *stateBase) peerState { | ||||
| 	s := &stateServer{b} | ||||
| 	s.publish() | ||||
| 	s.stopPingTimer() | ||||
| 	s.stopTimeoutTimer() | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *stateServer) Name() string { return "server" } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type stateSelectRelay struct { | ||||
| 	*stateBase | ||||
| } | ||||
|  | ||||
| func newStateSelectRelay(b *stateBase) peerState { | ||||
| 	s := &stateSelectRelay{stateBase: b} | ||||
| 	s.staged.dataCipher = nil | ||||
| 	s.staged.up = false | ||||
| 	s.publish() | ||||
|  | ||||
| 	if relay := s.selectRelay(); relay != 0 { | ||||
| 		s.staged.up = false | ||||
| 		s.staged.relayIP = relay | ||||
| 		return s.selectRole() | ||||
| 	} | ||||
|  | ||||
| 	s.resetPingTimer() | ||||
| 	s.stopTimeoutTimer() | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *stateSelectRelay) selectRole() peerState { | ||||
| 	if s.localIP < s.remoteIP { | ||||
| 		return newStateServerRelayed(s.stateBase) | ||||
| 	} | ||||
| 	return newStateClientRelayed(s.stateBase) | ||||
| } | ||||
|  | ||||
| func (s *stateSelectRelay) Name() string { return "select-relay" } | ||||
|  | ||||
| func (s *stateSelectRelay) OnPingTimer() peerState { | ||||
| 	if relay := s.selectRelay(); relay != 0 { | ||||
| 		s.logf("Got relay IP: %d", relay) | ||||
| 		s.staged.relayIP = relay | ||||
| 		return s.selectRole() | ||||
| 	} | ||||
| 	s.resetPingTimer() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type stateClientRelayed struct { | ||||
| 	sharedKey [32]byte | ||||
| 	*stateBase | ||||
| } | ||||
|  | ||||
| func newStateClientRelayed(b *stateBase) peerState { | ||||
| 	s := &stateClientRelayed{stateBase: b} | ||||
|  | ||||
| 	s.staged.dataCipher = newDataCipher() | ||||
| 	s.sharedKey = s.staged.dataCipher.Key() | ||||
| 	s.publish() | ||||
|  | ||||
| 	s.sendControlPacket(newPingPacket()) | ||||
| 	s.resetPingTimer() | ||||
| 	s.resetTimeoutTimer() | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *stateClientRelayed) Name() string { return "client-relayed" } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type stateServerRelayed struct { | ||||
| 	*stateBase | ||||
| } | ||||
|  | ||||
| func newStateServerRelayed(b *stateBase) peerState { | ||||
| 	s := &stateServerRelayed{b} | ||||
| 	s.stopPingTimer() | ||||
| 	s.resetTimeoutTimer() | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *stateServerRelayed) Name() string { return "server-relayed" } | ||||
|  | ||||
| func (s *stateServerRelayed) OnTimeoutTimer() peerState { | ||||
| 	return newStateSelectRelay(s.stateBase) | ||||
| } | ||||
							
								
								
									
										276
									
								
								node/peer-super-states.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										276
									
								
								node/peer-super-states.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,276 @@ | ||||
| package node | ||||
|  | ||||
| import ( | ||||
| 	"math/rand" | ||||
| 	"net/netip" | ||||
| 	"time" | ||||
| 	"vppn/m" | ||||
| ) | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSuper) noPeer() stateFunc { | ||||
| 	return s.peerUpdate(<-s.peerUpdates) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSuper) peerUpdate(peer *m.Peer) stateFunc { | ||||
| 	return func() stateFunc { return s._peerUpdate(peer) } | ||||
| } | ||||
|  | ||||
| func (s *peerSuper) _peerUpdate(peer *m.Peer) stateFunc { | ||||
| 	defer s.publish() | ||||
|  | ||||
| 	s.peer = peer | ||||
| 	s.staged = peerRoutingData{} | ||||
|  | ||||
| 	if s.peer == nil { | ||||
| 		return s.noPeer | ||||
| 	} | ||||
|  | ||||
| 	s.staged.controlCipher = newControlCipher(s.privKey, peer.EncPubKey) | ||||
| 	s.staged.dataCipher = newDataCipher() | ||||
|  | ||||
| 	if ip, isValid := netip.AddrFromSlice(peer.PublicIP); isValid { | ||||
| 		s.remotePub = true | ||||
| 		s.staged.relay = peer.Mediator | ||||
| 		s.staged.remoteAddr = netip.AddrPortFrom(ip, peer.Port) | ||||
| 	} | ||||
|  | ||||
| 	if s.remotePub == s.localPub { | ||||
| 		if s.localIP < s.remoteIP { | ||||
| 			return s.serverAccept | ||||
| 		} | ||||
| 		return s.clientInit | ||||
| 	} | ||||
|  | ||||
| 	if s.remotePub { | ||||
| 		return s.clientInit | ||||
| 	} | ||||
| 	return s.serverAccept | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSuper) serverAccept() stateFunc { | ||||
| 	s.logf("STATE: server-accept") | ||||
| 	s.staged.up = false | ||||
| 	s.staged.dataCipher = nil | ||||
| 	s.staged.remoteAddr = zeroAddrPort | ||||
| 	s.staged.relayIP = 0 | ||||
| 	s.publish() | ||||
|  | ||||
| 	var syn synPacket | ||||
|  | ||||
| 	for { | ||||
| 		select { | ||||
| 		case peer := <-s.peerUpdates: | ||||
| 			return s.peerUpdate(peer) | ||||
|  | ||||
| 		case pkt := <-s.controlPackets: | ||||
| 			switch p := pkt.Payload.(type) { | ||||
|  | ||||
| 			case synPacket: | ||||
| 				syn = p | ||||
| 				s.staged.remoteAddr = pkt.RemoteAddr | ||||
| 				s.staged.dataCipher = newDataCipherFromKey(syn.SharedKey) | ||||
| 				s.staged.relayIP = syn.RelayIP | ||||
| 				s.publish() | ||||
| 				s.sendControlPacket(newSynAckPacket(p.TraceID)) | ||||
|  | ||||
| 			case ackPacket: | ||||
| 				if p.TraceID != syn.TraceID { | ||||
| 					continue | ||||
| 				} | ||||
|  | ||||
| 				// Publish. | ||||
| 				return s.serverConnected(syn.TraceID) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSuper) serverConnected(traceID uint64) stateFunc { | ||||
| 	s.logf("STATE: server-connected") | ||||
| 	s.staged.up = true | ||||
| 	s.publish() | ||||
| 	return func() stateFunc { | ||||
| 		return s._serverConnected(traceID) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *peerSuper) _serverConnected(traceID uint64) stateFunc { | ||||
|  | ||||
| 	timeoutTimer := time.NewTimer(timeoutInterval) | ||||
| 	defer timeoutTimer.Stop() | ||||
|  | ||||
| 	for { | ||||
| 		select { | ||||
| 		case peer := <-s.peerUpdates: | ||||
| 			return s.peerUpdate(peer) | ||||
|  | ||||
| 		case pkt := <-s.controlPackets: | ||||
| 			switch p := pkt.Payload.(type) { | ||||
|  | ||||
| 			case ackPacket: | ||||
| 				if p.TraceID != traceID { | ||||
| 					return s.serverAccept | ||||
| 				} | ||||
|  | ||||
| 				s.sendControlPacket(ackPacket{TraceID: traceID}) | ||||
| 				timeoutTimer.Reset(timeoutInterval) | ||||
| 			} | ||||
|  | ||||
| 		case <-timeoutTimer.C: | ||||
| 			s.logf("server timeout") | ||||
| 			return s.serverAccept | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSuper) clientInit() stateFunc { | ||||
| 	s.logf("STATE: client-init") | ||||
| 	if !s.remotePub { | ||||
| 		// TODO: Check local discovery for IP. | ||||
| 		// TODO: Attempt UDP hole punch. | ||||
| 		// TODO: client-relayed | ||||
| 		return s.clientSelectRelay | ||||
| 	} | ||||
|  | ||||
| 	return s.clientDial | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSuper) clientSelectRelay() stateFunc { | ||||
| 	s.logf("STATE: client-select-relay") | ||||
|  | ||||
| 	timer := time.NewTimer(0) | ||||
| 	defer timer.Stop() | ||||
|  | ||||
| 	for { | ||||
| 		select { | ||||
| 		case peer := <-s.peerUpdates: | ||||
| 			return s.peerUpdate(peer) | ||||
|  | ||||
| 		case <-timer.C: | ||||
| 			ip := s.selectRelayIP() | ||||
| 			if ip != 0 { | ||||
| 				s.logf("Got relay: %d", ip) | ||||
| 				s.staged.relayIP = ip | ||||
| 				s.publish() | ||||
| 				return s.clientDial | ||||
| 			} | ||||
|  | ||||
| 			s.logf("No relay available.") | ||||
| 			timer.Reset(pingInterval) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *peerSuper) selectRelayIP() byte { | ||||
| 	possible := make([]byte, 0, 8) | ||||
| 	for i, peer := range s.peers { | ||||
| 		if peer.CanRelay() { | ||||
| 			possible = append(possible, byte(i)) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if len(possible) == 0 { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	return possible[rand.Intn(len(possible))] | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSuper) clientDial() stateFunc { | ||||
| 	s.logf("STATE: client-dial") | ||||
|  | ||||
| 	var ( | ||||
| 		syn = synPacket{ | ||||
| 			TraceID:   newTraceID(), | ||||
| 			SharedKey: s.staged.dataCipher.Key(), | ||||
| 			RelayIP:   s.staged.relayIP, | ||||
| 		} | ||||
|  | ||||
| 		timeout = time.NewTimer(dialTimeout) | ||||
| 	) | ||||
|  | ||||
| 	defer timeout.Stop() | ||||
|  | ||||
| 	s.sendControlPacket(syn) | ||||
|  | ||||
| 	for { | ||||
| 		select { | ||||
|  | ||||
| 		case peer := <-s.peerUpdates: | ||||
| 			return s.peerUpdate(peer) | ||||
|  | ||||
| 		case pkt := <-s.controlPackets: | ||||
| 			switch p := pkt.Payload.(type) { | ||||
| 			case synAckPacket: | ||||
| 				if p.TraceID != syn.TraceID { | ||||
| 					continue // Hmm... | ||||
| 				} | ||||
| 				s.sendControlPacket(ackPacket{TraceID: syn.TraceID}) | ||||
| 				return s.clientConnected(syn.TraceID) | ||||
| 			} | ||||
|  | ||||
| 		case <-timeout.C: | ||||
| 			return s.clientInit | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSuper) clientConnected(traceID uint64) stateFunc { | ||||
| 	s.logf("STATE: client-connected") | ||||
| 	s.staged.up = true | ||||
| 	s.publish() | ||||
|  | ||||
| 	return func() stateFunc { | ||||
| 		return s._clientConnected(traceID) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *peerSuper) _clientConnected(traceID uint64) stateFunc { | ||||
|  | ||||
| 	pingTimer := time.NewTimer(pingInterval) | ||||
| 	timeoutTimer := time.NewTimer(timeoutInterval) | ||||
|  | ||||
| 	defer pingTimer.Stop() | ||||
| 	defer timeoutTimer.Stop() | ||||
|  | ||||
| 	for { | ||||
| 		select { | ||||
| 		case peer := <-s.peerUpdates: | ||||
| 			return s.peerUpdate(peer) | ||||
|  | ||||
| 		case pkt := <-s.controlPackets: | ||||
| 			switch p := pkt.Payload.(type) { | ||||
|  | ||||
| 			case ackPacket: | ||||
| 				if p.TraceID != traceID { | ||||
| 					return s.clientInit | ||||
| 				} | ||||
| 				timeoutTimer.Reset(timeoutInterval) | ||||
| 			} | ||||
|  | ||||
| 		case <-pingTimer.C: | ||||
| 			s.sendControlPacket(ackPacket{TraceID: traceID}) | ||||
| 			pingTimer.Reset(pingInterval) | ||||
|  | ||||
| 		case <-timeoutTimer.C: | ||||
| 			s.logf("client timeout") | ||||
| 			return s.clientInit | ||||
|  | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										80
									
								
								node/peer-super.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										80
									
								
								node/peer-super.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,80 @@ | ||||
| package node | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"sync/atomic" | ||||
| 	"vppn/m" | ||||
| ) | ||||
|  | ||||
| type peerSuper struct { | ||||
| 	// The purpose of this state machine is to manage this published data. | ||||
| 	published *atomic.Pointer[peerRoutingData] | ||||
| 	staged    peerRoutingData // Local copy of shared data. See publish(). | ||||
|  | ||||
| 	// The other remote peers. | ||||
| 	peers *remotePeers | ||||
|  | ||||
| 	// Immutable data. | ||||
| 	localIP  byte | ||||
| 	localPub bool | ||||
| 	remoteIP byte | ||||
| 	privKey  []byte | ||||
| 	conn     *connWriter | ||||
|  | ||||
| 	// For sending to peer. | ||||
| 	counter *uint64 | ||||
|  | ||||
| 	// Mutable peer data. | ||||
| 	peer      *m.Peer | ||||
| 	remotePub bool | ||||
|  | ||||
| 	// Incoming events. | ||||
| 	peerUpdates    chan *m.Peer | ||||
| 	controlPackets chan controlPacket | ||||
|  | ||||
| 	// Buffers | ||||
| 	buf    []byte | ||||
| 	encBuf []byte | ||||
| } | ||||
|  | ||||
| type stateFunc func() stateFunc | ||||
|  | ||||
| func (s *peerSuper) Run() { | ||||
| 	state := s.noPeer | ||||
| 	for { | ||||
| 		state = state() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSuper) logf(msg string, args ...any) { | ||||
| 	log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSuper) publish() { | ||||
| 	data := s.staged | ||||
| 	s.published.Store(&data) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { | ||||
| 	buf := pkt.Marshal(s.buf) | ||||
| 	h := header{ | ||||
| 		StreamID: controlStreamID, | ||||
| 		Counter:  atomic.AddUint64(s.counter, 1), | ||||
| 		SourceIP: s.localIP, | ||||
| 		DestIP:   s.remoteIP, | ||||
| 	} | ||||
|  | ||||
| 	buf = s.staged.controlCipher.Encrypt(h, buf, s.encBuf) | ||||
| 	if s.staged.relayIP != 0 { | ||||
| 		s.peers[s.staged.relayIP].RelayTo(s.remoteIP, buf) | ||||
| 	} else { | ||||
| 		s.conn.WriteTo(buf, s.staged.remoteAddr) | ||||
| 	} | ||||
| } | ||||
| @@ -15,55 +15,20 @@ const ( | ||||
| func (rp *remotePeer) supervise(conf m.PeerConfig) { | ||||
| 	defer panicHandler() | ||||
|  | ||||
| 	base := &stateBase{ | ||||
| 		published:    rp.published, | ||||
| 		peers:        rp.peers, | ||||
| 		localIP:      rp.localIP, | ||||
| 		remoteIP:     rp.remoteIP, | ||||
| 		privKey:      conf.EncPrivKey, | ||||
| 		localPub:     addrIsValid(conf.PublicIP), | ||||
| 		conn:         rp.conn, | ||||
| 		counter:      &rp.counter, | ||||
| 		pingTimer:    time.NewTimer(time.Second), | ||||
| 		timeoutTimer: time.NewTimer(time.Second), | ||||
| 		buf:          make([]byte, bufferSize), | ||||
| 		encBuf:       make([]byte, bufferSize), | ||||
| 	super := &peerSuper{ | ||||
| 		published:      rp.published, | ||||
| 		peers:          rp.peers, | ||||
| 		localIP:        rp.localIP, | ||||
| 		localPub:       addrIsValid(conf.PublicIP), | ||||
| 		remoteIP:       rp.remoteIP, | ||||
| 		privKey:        conf.EncPrivKey, | ||||
| 		conn:           rp.conn, | ||||
| 		counter:        &rp.counter, | ||||
| 		peerUpdates:    rp.peerUpdates, | ||||
| 		controlPackets: rp.controlPackets, | ||||
| 		buf:            make([]byte, bufferSize), | ||||
| 		encBuf:         make([]byte, bufferSize), | ||||
| 	} | ||||
|  | ||||
| 	var ( | ||||
| 		curState  peerState = newStateNoPeer(base) | ||||
| 		nextState peerState | ||||
| 	) | ||||
|  | ||||
| 	for { | ||||
| 		nextState = nil | ||||
|  | ||||
| 		select { | ||||
| 		case peer := <-rp.peerUpdates: | ||||
| 			nextState = curState.OnPeerUpdate(peer) | ||||
|  | ||||
| 		case pkt := <-rp.controlPackets: | ||||
| 			switch p := pkt.Payload.(type) { | ||||
| 			case synPacket: | ||||
| 				nextState = curState.OnSyn(pkt.RemoteAddr, p) | ||||
| 			case synAckPacket: | ||||
| 				nextState = curState.OnSynAck(pkt.RemoteAddr, p) | ||||
| 			case ackPacket: | ||||
| 				nextState = curState.OnAck(pkt.RemoteAddr, p) | ||||
| 			default: | ||||
| 				// Unknown packet type. | ||||
| 			} | ||||
|  | ||||
| 		case <-base.pingTimer.C: | ||||
| 			nextState = curState.OnPingTimer() | ||||
|  | ||||
| 		case <-base.timeoutTimer.C: | ||||
| 			nextState = curState.OnTimeoutTimer() | ||||
| 		} | ||||
|  | ||||
| 		if nextState != nil { | ||||
| 			rp.logf("%s --> %s", curState.Name(), nextState.Name()) | ||||
| 			curState = nextState | ||||
| 		} | ||||
| 	} | ||||
| 	go super.Run() | ||||
| } | ||||
|   | ||||
							
								
								
									
										15
									
								
								node/peer.go
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								node/peer.go
									
									
									
									
									
								
							| @@ -41,6 +41,10 @@ type remotePeer struct { | ||||
| 	// Used for sending control and data packets. Atomic access only. | ||||
| 	counter uint64 | ||||
|  | ||||
| 	// Only accessed in HandlePeerUpdate. Used to determine if we should send | ||||
| 	// the peer update to the peerSuper. | ||||
| 	peerVersion int64 | ||||
|  | ||||
| 	// For communicating with the supervisor thread. | ||||
| 	peerUpdates    chan *m.Peer | ||||
| 	controlPackets chan controlPacket | ||||
| @@ -75,7 +79,12 @@ func (rp *remotePeer) logf(msg string, args ...any) { | ||||
| } | ||||
|  | ||||
| func (rp *remotePeer) HandlePeerUpdate(peer *m.Peer) { | ||||
| 	rp.peerUpdates <- peer | ||||
| 	if peer == nil { | ||||
| 		rp.peerUpdates <- peer | ||||
| 	} else if peer.Version != rp.peerVersion { | ||||
| 		rp.peerVersion = peer.Version | ||||
| 		rp.peerUpdates <- peer | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
| @@ -209,7 +218,7 @@ func (rp *remotePeer) HandleInterfacePacket(data []byte) { | ||||
| 	enc := routingData.dataCipher.Encrypt(h, data, rp.encryptBuf) | ||||
|  | ||||
| 	if routingData.relayIP != 0 { | ||||
| 		rp.peers[routingData.relayIP].RelayFor(rp.remoteIP, enc) | ||||
| 		rp.peers[routingData.relayIP].RelayTo(rp.remoteIP, enc) | ||||
| 	} else { | ||||
| 		rp.SendData(data) | ||||
| 	} | ||||
| @@ -224,7 +233,7 @@ func (rp *remotePeer) CanRelay() bool { | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (rp *remotePeer) RelayFor(destIP byte, data []byte) { | ||||
| func (rp *remotePeer) RelayTo(destIP byte, data []byte) { | ||||
| 	rp.encryptAndSend(relayStreamID, destIP, data) | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -1,7 +0,0 @@ | ||||
| package node | ||||
|  | ||||
| import ( | ||||
| 	"net/netip" | ||||
| ) | ||||
|  | ||||
| var zeroAddrPort = netip.AddrPort{} | ||||
		Reference in New Issue
	
	Block a user