WIP: Working
This commit is contained in:
		| @@ -2,6 +2,7 @@ | ||||
|  | ||||
| ## Roadmap | ||||
|  | ||||
| * Rename Mediator -> Relay | ||||
| * Node: use symmetric encryption after handshake | ||||
| * AEAD-AES uses a 12 byte nonce. We need to shrink the header: | ||||
|   * Remove Forward and replace it with a HeaderFlags bitfield. | ||||
|   | ||||
							
								
								
									
										8
									
								
								node/addrutil.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								node/addrutil.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,8 @@ | ||||
| package node | ||||
|  | ||||
| import "net/netip" | ||||
|  | ||||
| func addrIsValid(in []byte) bool { | ||||
| 	_, ok := netip.AddrFromSlice(in) | ||||
| 	return ok | ||||
| } | ||||
							
								
								
									
										26
									
								
								node/cipher-control.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								node/cipher-control.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | ||||
| package node | ||||
|  | ||||
| import "golang.org/x/crypto/nacl/box" | ||||
|  | ||||
| type controlCipher struct { | ||||
| 	sharedKey [32]byte | ||||
| } | ||||
|  | ||||
| func newControlCipher(privKey, pubKey []byte) *controlCipher { | ||||
| 	shared := [32]byte{} | ||||
| 	box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey)) | ||||
| 	return &controlCipher{shared} | ||||
| } | ||||
|  | ||||
| func (cc *controlCipher) Encrypt(h xHeader, data, out []byte) []byte { | ||||
| 	const s = controlHeaderSize | ||||
| 	out = out[:s+controlCipherOverhead+len(data)] | ||||
| 	h.Marshal(out[:s]) | ||||
| 	box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &cc.sharedKey) | ||||
| 	return out | ||||
| } | ||||
|  | ||||
| func (cc *controlCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) { | ||||
| 	const s = controlHeaderSize | ||||
| 	return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &cc.sharedKey) | ||||
| } | ||||
| @@ -3,12 +3,13 @@ package node | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"crypto/rand" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"golang.org/x/crypto/nacl/box" | ||||
| ) | ||||
| 
 | ||||
| func newRoutingCipherForTesting() (c1, c2 routingCipher) { | ||||
| func newControlCipherForTesting() (c1, c2 *controlCipher) { | ||||
| 	pubKey1, privKey1, err := box.GenerateKey(rand.Reader) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| @@ -19,14 +20,14 @@ func newRoutingCipherForTesting() (c1, c2 routingCipher) { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 
 | ||||
| 	return newRoutingCipher(privKey1[:], pubKey2[:]), | ||||
| 		newRoutingCipher(privKey2[:], pubKey1[:]) | ||||
| 	return newControlCipher(privKey1[:], pubKey2[:]), | ||||
| 		newControlCipher(privKey2[:], pubKey1[:]) | ||||
| } | ||||
| 
 | ||||
| func TestRoutingCipher(t *testing.T) { | ||||
| 	c1, c2 := newRoutingCipherForTesting() | ||||
| func TestControlCipher(t *testing.T) { | ||||
| 	c1, c2 := newControlCipherForTesting() | ||||
| 
 | ||||
| 	maxSizePlaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead) | ||||
| 	maxSizePlaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) | ||||
| 	rand.Read(maxSizePlaintext) | ||||
| 
 | ||||
| 	testCases := [][]byte{ | ||||
| @@ -40,6 +41,7 @@ func TestRoutingCipher(t *testing.T) { | ||||
| 
 | ||||
| 	for _, plaintext := range testCases { | ||||
| 		h1 := xHeader{ | ||||
| 			StreamID: controlStreamID, | ||||
| 			Counter:  235153, | ||||
| 			SourceIP: 4, | ||||
| 			DestIP:   88, | ||||
| @@ -49,6 +51,12 @@ func TestRoutingCipher(t *testing.T) { | ||||
| 
 | ||||
| 		encrypted = c1.Encrypt(h1, plaintext, encrypted) | ||||
| 
 | ||||
| 		h2 := xHeader{} | ||||
| 		h2.Parse(encrypted) | ||||
| 		if !reflect.DeepEqual(h1, h2) { | ||||
| 			t.Fatal(h1, h2) | ||||
| 		} | ||||
| 
 | ||||
| 		decrypted, ok := c2.Decrypt(encrypted, make([]byte, bufferSize)) | ||||
| 		if !ok { | ||||
| 			t.Fatal(ok) | ||||
| @@ -60,9 +68,9 @@ func TestRoutingCipher(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestRoutingCipher_ShortCiphertext(t *testing.T) { | ||||
| 	c1, _ := newRoutingCipherForTesting() | ||||
| 	shortText := make([]byte, routingHeaderSize+routingCipherOverhead-1) | ||||
| func TestControlCipher_ShortCiphertext(t *testing.T) { | ||||
| 	c1, _ := newControlCipherForTesting() | ||||
| 	shortText := make([]byte, controlHeaderSize+controlCipherOverhead-1) | ||||
| 	rand.Read(shortText) | ||||
| 	_, ok := c1.Decrypt(shortText, make([]byte, bufferSize)) | ||||
| 	if ok { | ||||
| @@ -70,15 +78,15 @@ func TestRoutingCipher_ShortCiphertext(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func BenchmarkRoutingCipher_Encrypt(b *testing.B) { | ||||
| 	c1, _ := newRoutingCipherForTesting() | ||||
| func BenchmarkControlCipher_Encrypt(b *testing.B) { | ||||
| 	c1, _ := newControlCipherForTesting() | ||||
| 	h1 := xHeader{ | ||||
| 		Counter:  235153, | ||||
| 		SourceIP: 4, | ||||
| 		DestIP:   88, | ||||
| 	} | ||||
| 
 | ||||
| 	plaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead) | ||||
| 	plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) | ||||
| 	rand.Read(plaintext) | ||||
| 
 | ||||
| 	encrypted := make([]byte, bufferSize) | ||||
| @@ -89,8 +97,8 @@ func BenchmarkRoutingCipher_Encrypt(b *testing.B) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func BenchmarkRoutingCipher_Decrypt(b *testing.B) { | ||||
| 	c1, c2 := newRoutingCipherForTesting() | ||||
| func BenchmarkControlCipher_Decrypt(b *testing.B) { | ||||
| 	c1, c2 := newControlCipherForTesting() | ||||
| 
 | ||||
| 	h1 := xHeader{ | ||||
| 		Counter:  235153, | ||||
| @@ -98,7 +106,7 @@ func BenchmarkRoutingCipher_Decrypt(b *testing.B) { | ||||
| 		DestIP:   88, | ||||
| 	} | ||||
| 
 | ||||
| 	plaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead) | ||||
| 	plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) | ||||
| 	rand.Read(plaintext) | ||||
| 
 | ||||
| 	encrypted := make([]byte, bufferSize) | ||||
| @@ -6,22 +6,23 @@ import ( | ||||
| 	"crypto/rand" | ||||
| ) | ||||
|  | ||||
| // TODO: Use [32]byte for simplicity everywhere. | ||||
| type dataCipher struct { | ||||
| 	key  []byte | ||||
| 	key  [32]byte | ||||
| 	aead cipher.AEAD | ||||
| } | ||||
|  | ||||
| func newDataCipher() *dataCipher { | ||||
| 	key := make([]byte, 32) | ||||
| 	if _, err := rand.Read(key); err != nil { | ||||
| 	key := [32]byte{} | ||||
| 	if _, err := rand.Read(key[:]); err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	return newDataCipherFromKey(key) | ||||
| } | ||||
|  | ||||
| // key must be 32 bytes. | ||||
| func newDataCipherFromKey(key []byte) *dataCipher { | ||||
| 	block, err := aes.NewCipher(key) | ||||
| func newDataCipherFromKey(key [32]byte) *dataCipher { | ||||
| 	block, err := aes.NewCipher(key[:]) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| @@ -34,14 +35,14 @@ func newDataCipherFromKey(key []byte) *dataCipher { | ||||
| 	return &dataCipher{key: key, aead: aead} | ||||
| } | ||||
|  | ||||
| func (sc *dataCipher) Key() []byte { | ||||
| func (sc *dataCipher) Key() [32]byte { | ||||
| 	return sc.key | ||||
| } | ||||
|  | ||||
| func (sc *dataCipher) Encrypt(h xHeader, data, out []byte) []byte { | ||||
| 	const s = dataHeaderSize | ||||
| 	out = out[:s+dataCipherOverhead+len(data)] | ||||
| 	h.Marshal(dataStreamID, out[:s]) | ||||
| 	h.Marshal(out[:s]) | ||||
| 	sc.aead.Seal(out[s:s], out[:s], data, nil) | ||||
| 	return out | ||||
| } | ||||
|   | ||||
| @@ -23,6 +23,7 @@ func TestDataCipher(t *testing.T) { | ||||
|  | ||||
| 	for _, plaintext := range testCases { | ||||
| 		h1 := xHeader{ | ||||
| 			StreamID: dataStreamID, | ||||
| 			Counter:  235153, | ||||
| 			SourceIP: 4, | ||||
| 			DestIP:   88, | ||||
|   | ||||
| @@ -1,26 +0,0 @@ | ||||
| package node | ||||
|  | ||||
| import "golang.org/x/crypto/nacl/box" | ||||
|  | ||||
| type routingCipher struct { | ||||
| 	sharedKey [32]byte | ||||
| } | ||||
|  | ||||
| func newRoutingCipher(privKey, pubKey []byte) routingCipher { | ||||
| 	shared := [32]byte{} | ||||
| 	box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey)) | ||||
| 	return routingCipher{shared} | ||||
| } | ||||
|  | ||||
| func (rc routingCipher) Encrypt(h xHeader, data, out []byte) []byte { | ||||
| 	const s = routingHeaderSize | ||||
| 	out = out[:s+routingCipherOverhead+len(data)] | ||||
| 	h.Marshal(routingStreamID, out[:s]) | ||||
| 	box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &rc.sharedKey) | ||||
| 	return out | ||||
| } | ||||
|  | ||||
| func (rc routingCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) { | ||||
| 	const s = routingHeaderSize | ||||
| 	return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &rc.sharedKey) | ||||
| } | ||||
							
								
								
									
										43
									
								
								node/conn.go
									
									
									
									
									
								
							
							
						
						
									
										43
									
								
								node/conn.go
									
									
									
									
									
								
							| @@ -1,6 +1,7 @@ | ||||
| package node | ||||
|  | ||||
| import ( | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net" | ||||
| 	"net/netip" | ||||
| @@ -9,6 +10,48 @@ import ( | ||||
| 	"vppn/fasttime" | ||||
| ) | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type connWriter2 struct { | ||||
| 	lock sync.Mutex | ||||
| 	conn *net.UDPConn | ||||
| } | ||||
|  | ||||
| func newConnWriter2(conn *net.UDPConn) *connWriter2 { | ||||
| 	return &connWriter2{conn: conn} | ||||
| } | ||||
|  | ||||
| func (w *connWriter2) WriteTo(packet []byte, addr netip.AddrPort) { | ||||
| 	w.lock.Lock() | ||||
| 	if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil { | ||||
| 		log.Fatalf("Failed to write to UDP port: %v", err) | ||||
| 	} | ||||
| 	w.lock.Unlock() | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type ifWriter struct { | ||||
| 	lock  sync.Mutex | ||||
| 	iface io.ReadWriteCloser | ||||
| } | ||||
|  | ||||
| func newIFWriter(iface io.ReadWriteCloser) *ifWriter { | ||||
| 	return &ifWriter{iface: iface} | ||||
| } | ||||
|  | ||||
| func (w *ifWriter) Write(packet []byte) { | ||||
| 	w.lock.Lock() | ||||
| 	if _, err := w.iface.Write(packet); err != nil { | ||||
| 		log.Fatalf("Failed to write to interface: %v", err) | ||||
| 	} | ||||
| 	w.lock.Unlock() | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| // TODO: Delete below?? | ||||
|  | ||||
| type connWriter struct { | ||||
| 	*net.UDPConn | ||||
| 	lock     sync.Mutex | ||||
|   | ||||
| @@ -5,30 +5,33 @@ import "unsafe" | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| const ( | ||||
| 	routingStreamID       = 2 | ||||
| 	routingHeaderSize     = 24 | ||||
| 	routingCipherOverhead = 16 | ||||
| 	controlStreamID       = 2 | ||||
| 	controlHeaderSize     = 24 | ||||
| 	controlCipherOverhead = 16 | ||||
|  | ||||
| 	dataStreamID       = 1 | ||||
| 	dataHeaderSize     = 12 | ||||
| 	dataCipherOverhead = 16 | ||||
|  | ||||
| 	forwardStreamID = 3 | ||||
| ) | ||||
|  | ||||
| // TODO: Rename | ||||
| type xHeader struct { | ||||
| 	StreamID byte | ||||
| 	Counter  uint64 // Init with fasttime.Now() << 30 to ensure monotonic. | ||||
| 	SourceIP byte | ||||
| 	DestIP   byte | ||||
| } | ||||
|  | ||||
| func (h *xHeader) Parse(b []byte) { | ||||
| 	h.StreamID = b[0] | ||||
| 	h.Counter = *(*uint64)(unsafe.Pointer(&b[1])) | ||||
| 	h.SourceIP = b[9] | ||||
| 	h.DestIP = b[10] | ||||
| } | ||||
|  | ||||
| func (h *xHeader) Marshal(streamID byte, buf []byte) { | ||||
| 	buf[0] = streamID | ||||
| func (h *xHeader) Marshal(buf []byte) { | ||||
| 	buf[0] = h.StreamID | ||||
| 	*(*uint64)(unsafe.Pointer(&buf[1])) = h.Counter | ||||
| 	buf[9] = h.SourceIP | ||||
| 	buf[10] = h.DestIP | ||||
| @@ -40,7 +43,7 @@ func (h *xHeader) Marshal(streamID byte, buf []byte) { | ||||
| const ( | ||||
| 	headerSize    = 24 | ||||
| 	streamData    = 1 | ||||
| 	streamRouting = 2 | ||||
| 	streamControl = 2 | ||||
| ) | ||||
|  | ||||
| type header struct { | ||||
|   | ||||
| @@ -3,18 +3,17 @@ package node | ||||
| import "testing" | ||||
|  | ||||
| func TestHeaderMarshalParse(t *testing.T) { | ||||
| 	nIn := header{ | ||||
| 	nIn := xHeader{ | ||||
| 		StreamID: 23, | ||||
| 		Counter:  3212, | ||||
| 		SourceIP: 34, | ||||
| 		DestIP:   200, | ||||
| 		Forward:  1, | ||||
| 		Stream:   44, | ||||
| 	} | ||||
|  | ||||
| 	buf := make([]byte, headerSize) | ||||
| 	nIn.Marshal(buf) | ||||
|  | ||||
| 	nOut := header{} | ||||
| 	nOut := xHeader{} | ||||
| 	nOut.Parse(buf) | ||||
| 	if nIn != nOut { | ||||
| 		t.Fatal(nIn, nOut) | ||||
|   | ||||
							
								
								
									
										59
									
								
								node/main.go
									
									
									
									
									
								
							
							
						
						
									
										59
									
								
								node/main.go
									
									
									
									
									
								
							| @@ -102,15 +102,19 @@ func main(netName, listenIP string, port uint16) { | ||||
| 		log.Fatalf("Failed to open UDP port: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	routing := newRoutingTable() | ||||
| 	connWriter := newConnWriter2(conn) | ||||
| 	ifWriter := newIFWriter(iface) | ||||
|  | ||||
| 	w := newConnWriter(conn, conf.PeerIP, routing) | ||||
| 	r := newConnReader(conn, conf.PeerIP, routing) | ||||
| 	peers := remotePeers{} | ||||
|  | ||||
| 	router := newRouter(netName, conf, routing, w) | ||||
| 	for i := range peers { | ||||
| 		peers[i] = newRemotePeer(conf, byte(i), ifWriter, connWriter) | ||||
| 	} | ||||
|  | ||||
| 	go newHubPoller(netName, conf, peers).Run() | ||||
| 	go readFromConn(conn, peers) | ||||
| 	readFromIFace(iface, peers) | ||||
|  | ||||
| 	go nodeConnReader(r, w, iface, router) | ||||
| 	nodeIFaceReader(w, iface, router) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
| @@ -127,43 +131,39 @@ func determinePort(confPort, portFromCommandLine uint16) uint16 { | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func nodeConnReader(r *connReader, w *connWriter, iface io.ReadWriteCloser, router *router) { | ||||
| func readFromConn(conn *net.UDPConn, peers remotePeers) { | ||||
|  | ||||
| 	defer panicHandler() | ||||
|  | ||||
| 	var ( | ||||
| 		remoteAddr netip.AddrPort | ||||
| 		h          header | ||||
| 		n          int | ||||
| 		err        error | ||||
| 		buf        = make([]byte, bufferSize) | ||||
| 		data       []byte | ||||
| 		err        error | ||||
| 		h          xHeader | ||||
| 	) | ||||
|  | ||||
| 	for { | ||||
| 		remoteAddr, h, data = r.Read(buf) | ||||
|  | ||||
| 		if h.Forward != 0 { | ||||
| 			w.Forward(h.DestIP, data) | ||||
| 			continue | ||||
| 		n, remoteAddr, err = conn.ReadFromUDPAddrPort(buf[:bufferSize]) | ||||
| 		if err != nil { | ||||
| 			log.Fatalf("Failed to read from UDP port: %v", err) | ||||
| 		} | ||||
|  | ||||
| 		switch h.Stream { | ||||
| 		data = buf[:n] | ||||
|  | ||||
| 		case streamData: | ||||
| 			if _, err = iface.Write(data); err != nil { | ||||
| 				log.Printf("Malformed data from peer %d: %v", h.SourceIP, err) | ||||
| 			} | ||||
|  | ||||
| 		case streamRouting: | ||||
| 			router.HandlePacket(h.SourceIP, remoteAddr, data) | ||||
|  | ||||
| 		default: | ||||
| 			log.Printf("Dropping unknown stream: %d", h.Stream) | ||||
| 		if n < headerSize { | ||||
| 			continue // Packet it soo short. | ||||
| 		} | ||||
|  | ||||
| 		h.Parse(data) | ||||
| 		peers[h.SourceIP].HandlePacket(remoteAddr, h, data) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func nodeIFaceReader(w *connWriter, iface io.ReadWriteCloser, router *router) { | ||||
| func readFromIFace(iface io.ReadWriteCloser, peers remotePeers) { | ||||
|  | ||||
| 	var ( | ||||
| 		buf      = make([]byte, bufferSize) | ||||
| @@ -173,16 +173,11 @@ func nodeIFaceReader(w *connWriter, iface io.ReadWriteCloser, router *router) { | ||||
| 	) | ||||
|  | ||||
| 	for { | ||||
|  | ||||
| 		packet, remoteIP, err = readNextPacket(iface, buf) | ||||
| 		if err != nil { | ||||
| 			log.Fatalf("Failed to read from interface: %v", err) | ||||
| 		} | ||||
|  | ||||
| 		if remoteIP == w.localIP { | ||||
| 			continue // Don't write to self. | ||||
| 		} | ||||
|  | ||||
| 		w.WriteTo(remoteIP, streamData, packet) | ||||
| 		peers[remoteIP].SendData(packet) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -16,10 +16,10 @@ const ( | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type packetWrapper struct { | ||||
| type controlPacket struct { | ||||
| 	SrcIP      byte | ||||
| 	RemoteAddr netip.AddrPort | ||||
| 	Packet     any | ||||
| 	Payload    any | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
| @@ -46,13 +46,13 @@ func (p pingPacket) Marshal(buf []byte) []byte { | ||||
| 	return buf | ||||
| } | ||||
|  | ||||
| func (p *pingPacket) Parse(buf []byte) error { | ||||
| func parsePingPacket(buf []byte) (p pingPacket, err error) { | ||||
| 	if len(buf) != 41 { | ||||
| 		return errMalformedPacket | ||||
| 		return p, errMalformedPacket | ||||
| 	} | ||||
| 	p.SentAt = *(*int64)(unsafe.Pointer(&buf[1])) | ||||
| 	copy(p.SharedKey[:], buf[9:41]) | ||||
| 	return nil | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
| @@ -78,12 +78,11 @@ func (p pongPacket) Marshal(buf []byte) []byte { | ||||
| 	return buf | ||||
| } | ||||
|  | ||||
| func (p *pongPacket) Parse(buf []byte) error { | ||||
| func parsePongPacket(buf []byte) (p pongPacket, err error) { | ||||
| 	if len(buf) != 17 { | ||||
| 		return errMalformedPacket | ||||
| 		return p, errMalformedPacket | ||||
| 	} | ||||
| 	p.SentAt = *(*int64)(unsafe.Pointer(&buf[1])) | ||||
| 	p.RecvdAt = *(*int64)(unsafe.Pointer(&buf[9])) | ||||
|  | ||||
| 	return nil | ||||
| 	return | ||||
| } | ||||
|   | ||||
| @@ -15,8 +15,8 @@ func TestPacketPing(t *testing.T) { | ||||
| 	p := newPingPacket(sharedKey) | ||||
| 	out := p.Marshal(buf) | ||||
|  | ||||
| 	p2 := pingPacket{} | ||||
| 	if err := p2.Parse(out); err != nil { | ||||
| 	p2, err := parsePingPacket(out) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| @@ -31,8 +31,8 @@ func TestPacketPong(t *testing.T) { | ||||
| 	p := newPongPacket(123566) | ||||
| 	out := p.Marshal(buf) | ||||
|  | ||||
| 	p2 := pongPacket{} | ||||
| 	if err := p2.Parse(out); err != nil { | ||||
| 	p2, err := parsePongPacket(out) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
|   | ||||
							
								
								
									
										97
									
								
								node/peer-pollhub.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								node/peer-pollhub.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,97 @@ | ||||
| package node | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"time" | ||||
| 	"vppn/m" | ||||
| ) | ||||
|  | ||||
| type hubPoller struct { | ||||
| 	netName string | ||||
| 	localIP byte | ||||
| 	client  *http.Client | ||||
| 	req     *http.Request | ||||
| 	peers   remotePeers | ||||
| } | ||||
|  | ||||
| func newHubPoller(netName string, conf m.PeerConfig, peers remotePeers) *hubPoller { | ||||
| 	u, err := url.Parse(conf.HubAddress) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to parse hub address %s: %v", conf.HubAddress, err) | ||||
| 	} | ||||
| 	u.Path = "/peer/fetch-state/" | ||||
|  | ||||
| 	client := &http.Client{Timeout: 8 * time.Second} | ||||
|  | ||||
| 	req := &http.Request{ | ||||
| 		Method: http.MethodGet, | ||||
| 		URL:    u, | ||||
| 		Header: http.Header{}, | ||||
| 	} | ||||
| 	req.SetBasicAuth("", conf.APIKey) | ||||
|  | ||||
| 	return &hubPoller{ | ||||
| 		netName: netName, | ||||
| 		localIP: conf.PeerIP, | ||||
| 		client:  client, | ||||
| 		req:     req, | ||||
| 		peers:   peers, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (hp *hubPoller) Run() { | ||||
| 	defer panicHandler() | ||||
|  | ||||
| 	state, err := loadNetworkState(hp.netName) | ||||
| 	if err != nil { | ||||
| 		log.Printf("Failed to load network state: %v", err) | ||||
| 		log.Printf("Polling hub...") | ||||
| 		hp.pollHub() | ||||
| 	} else { | ||||
| 		hp.applyNetworkState(state) | ||||
| 	} | ||||
|  | ||||
| 	for range time.Tick(64 * time.Second) { | ||||
| 		hp.pollHub() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (hp *hubPoller) pollHub() { | ||||
| 	var state m.NetworkState | ||||
|  | ||||
| 	log.Printf("Fetching peer state...") | ||||
| 	resp, err := hp.client.Do(hp.req) | ||||
| 	if err != nil { | ||||
| 		log.Printf("Failed to fetch peer state: %v", err) | ||||
| 		return | ||||
| 	} | ||||
| 	body, err := io.ReadAll(resp.Body) | ||||
| 	_ = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		log.Printf("Failed to read body from hub: %v", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if err := json.Unmarshal(body, &state); err != nil { | ||||
| 		log.Printf("Failed to unmarshal response from hub: %v", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	hp.applyNetworkState(state) | ||||
|  | ||||
| 	if err := storeNetworkState(hp.netName, state); err != nil { | ||||
| 		log.Printf("Failed to store network state: %v", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (hp *hubPoller) applyNetworkState(state m.NetworkState) { | ||||
| 	for i := range state.Peers { | ||||
| 		if i != int(hp.localIP) { | ||||
| 			hp.peers[i].HandlePeerUpdate(state.Peers[i]) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										197
									
								
								node/peer-supervisor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										197
									
								
								node/peer-supervisor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,197 @@ | ||||
| package node | ||||
|  | ||||
| import ( | ||||
| 	"net/netip" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
| 	"vppn/m" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	connectTimeout  = 6 * time.Second | ||||
| 	pingInterval    = 6 * time.Second | ||||
| 	timeoutInterval = 20 * time.Second | ||||
| ) | ||||
|  | ||||
| type stateFunc func() stateFunc | ||||
|  | ||||
| type peerSuper struct { | ||||
| 	*remotePeer | ||||
|  | ||||
| 	peer         *m.Peer | ||||
| 	remotePublic bool | ||||
| 	peerData     peerData | ||||
|  | ||||
| 	pktBuf []byte | ||||
| 	encBuf []byte | ||||
| } | ||||
|  | ||||
| func newPeerSuper(rp *remotePeer) *peerSuper { | ||||
| 	return &peerSuper{ | ||||
| 		remotePeer: rp, | ||||
| 		peer:       nil, | ||||
| 		pktBuf:     make([]byte, bufferSize), | ||||
| 		encBuf:     make([]byte, bufferSize), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (rp *peerSuper) Run() { | ||||
| 	defer panicHandler() | ||||
| 	state := rp.stateInit | ||||
| 	for { | ||||
| 		state = state() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (rp *peerSuper) stateInit() stateFunc { | ||||
| 	//rp.logf("STATE: Init") | ||||
| 	x := peerData{} | ||||
| 	rp.shared.Store(&x) | ||||
|  | ||||
| 	rp.peerData.controlCipher = nil | ||||
| 	rp.peerData.dataCipher = nil | ||||
| 	rp.peerData.remoteAddr = zeroAddrPort | ||||
|  | ||||
| 	if rp.peer == nil { | ||||
| 		return rp.stateDisconnected | ||||
| 	} | ||||
|  | ||||
| 	var addr netip.Addr | ||||
| 	addr, rp.remotePublic = netip.AddrFromSlice(rp.peer.PublicIP) | ||||
| 	if rp.remotePublic { | ||||
| 		rp.peerData.remoteAddr = netip.AddrPortFrom(addr, rp.peer.Port) | ||||
| 	} | ||||
|  | ||||
| 	rp.peerData.controlCipher = newControlCipher(rp.privKey, rp.peer.EncPubKey) | ||||
|  | ||||
| 	return rp.stateSelectRole() | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (rp *peerSuper) stateDisconnected() stateFunc { | ||||
| 	//rp.logf("STATE: Disconnected") | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-rp.controlPackets: | ||||
| 			// Drop | ||||
| 		case rp.peer = <-rp.peerUpdates: | ||||
| 			return rp.stateInit | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (rp *peerSuper) stateSelectRole() stateFunc { | ||||
| 	rp.logf("STATE: SelectRole") | ||||
|  | ||||
| 	if !rp.localPublic && !rp.remotePublic { | ||||
| 		// TODO! | ||||
| 		return rp.stateDisconnected | ||||
| 	} | ||||
|  | ||||
| 	if !rp.localPublic { | ||||
| 		return rp.stateServer | ||||
| 	} else if !rp.remotePublic { | ||||
| 		return rp.stateClient | ||||
| 	} | ||||
|  | ||||
| 	if rp.localIP < rp.peer.PeerIP { | ||||
| 		return rp.stateClient | ||||
| 	} | ||||
| 	return rp.stateServer | ||||
| } | ||||
|  | ||||
| // The remote is a server. | ||||
| func (rp *peerSuper) stateServer() stateFunc { | ||||
| 	rp.logf("STATE: Server") | ||||
| 	rp.peerData.dataCipher = newDataCipher() | ||||
| 	rp.updateShared() | ||||
|  | ||||
| 	var ( | ||||
| 		pingTimer = time.NewTimer(pingInterval) | ||||
| 		ping      = pingPacket{SharedKey: ([32]byte)(rp.peerData.dataCipher.Key())} | ||||
| 	) | ||||
| 	defer pingTimer.Stop() | ||||
|  | ||||
| 	ping.SentAt = time.Now().UnixMilli() | ||||
| 	rp.sendControlPacket(ping) | ||||
|  | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-pingTimer.C: | ||||
| 			ping.SentAt = time.Now().UnixMilli() | ||||
| 			rp.sendControlPacket(ping) | ||||
| 			pingTimer.Reset(pingInterval) | ||||
|  | ||||
| 		case <-rp.controlPackets: | ||||
| 			// Ignore | ||||
|  | ||||
| 		case rp.peer = <-rp.peerUpdates: | ||||
| 			return rp.stateInit | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| // The remote is a client. | ||||
| func (rp *peerSuper) stateClient() stateFunc { | ||||
| 	rp.logf("STATE: Client") | ||||
| 	rp.updateShared() | ||||
|  | ||||
| 	// TODO: Could use timeout to set dataCipher to nil. | ||||
| 	var currentKey = [32]byte{} | ||||
|  | ||||
| 	for { | ||||
| 		select { | ||||
| 		case cPkt := <-rp.controlPackets: | ||||
| 			if cPkt.RemoteAddr != rp.peerData.remoteAddr { | ||||
| 				rp.peerData.remoteAddr = cPkt.RemoteAddr | ||||
| 				rp.logf("Got new remote address: %v", cPkt.RemoteAddr) | ||||
| 				rp.updateShared() | ||||
| 			} | ||||
|  | ||||
| 			ping, ok := cPkt.Payload.(pingPacket) | ||||
| 			if !ok { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			if ping.SharedKey != currentKey { | ||||
| 				rp.logf("Connected with new shared key") | ||||
| 				currentKey = ping.SharedKey | ||||
| 				rp.peerData.dataCipher = newDataCipherFromKey(currentKey) | ||||
| 				rp.updateShared() | ||||
| 			} | ||||
|  | ||||
| 			rp.sendControlPacket(newPongPacket(ping.SentAt)) | ||||
|  | ||||
| 		case rp.peer = <-rp.peerUpdates: | ||||
| 			return rp.stateInit | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (rp *peerSuper) updateShared() { | ||||
| 	data := rp.peerData | ||||
| 	rp.shared.Store(&data) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (rp *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) { | ||||
| 	buf := pkt.Marshal(rp.pktBuf) | ||||
| 	h := xHeader{ | ||||
| 		StreamID: controlStreamID, | ||||
| 		Counter:  atomic.AddUint64(&rp.counter, 1), | ||||
| 		SourceIP: rp.localIP, | ||||
| 		DestIP:   rp.remoteIP, | ||||
| 	} | ||||
| 	buf = rp.peerData.controlCipher.Encrypt(h, buf, rp.encBuf) | ||||
| 	rp.conn.WriteTo(buf, rp.peerData.remoteAddr) | ||||
| } | ||||
							
								
								
									
										205
									
								
								node/peer.go
									
									
									
									
									
								
							
							
						
						
									
										205
									
								
								node/peer.go
									
									
									
									
									
								
							| @@ -1 +1,206 @@ | ||||
| package node | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"net/netip" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
| 	"vppn/m" | ||||
| ) | ||||
|  | ||||
| type remotePeers [256]*remotePeer | ||||
|  | ||||
| type peerData struct { | ||||
| 	controlCipher *controlCipher | ||||
| 	dataCipher    *dataCipher | ||||
| 	remoteAddr    netip.AddrPort | ||||
| } | ||||
|  | ||||
| type remotePeer struct { | ||||
| 	// Immutable data. | ||||
| 	localIP     byte | ||||
| 	remoteIP    byte | ||||
| 	privKey     []byte | ||||
| 	localPublic bool // True if local node is public. | ||||
| 	iface       *ifWriter | ||||
| 	conn        *connWriter2 | ||||
|  | ||||
| 	// Shared state. | ||||
| 	shared *atomic.Pointer[peerData] | ||||
|  | ||||
| 	// Only used in HandlePeerUpdate. | ||||
| 	peerVersion int64 | ||||
|  | ||||
| 	// Only used in HandlePacket / Not synchronized. | ||||
| 	dupCheck   *dupCheck | ||||
| 	decryptBuf []byte | ||||
|  | ||||
| 	// Only used in SendData / Not synchronized. | ||||
| 	encryptBuf []byte | ||||
|  | ||||
| 	// Used for sending control and data packets. Atomic access only. | ||||
| 	counter uint64 | ||||
|  | ||||
| 	// For communicating with the supervisor thread. | ||||
| 	peerUpdates    chan *m.Peer | ||||
| 	controlPackets chan controlPacket | ||||
| } | ||||
|  | ||||
| func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *connWriter2) *remotePeer { | ||||
| 	rp := &remotePeer{ | ||||
| 		localIP:        conf.PeerIP, | ||||
| 		remoteIP:       remoteIP, | ||||
| 		privKey:        conf.EncPrivKey, | ||||
| 		localPublic:    addrIsValid(conf.PublicIP), | ||||
| 		iface:          iface, | ||||
| 		conn:           conn, | ||||
| 		shared:         &atomic.Pointer[peerData]{}, | ||||
| 		dupCheck:       newDupCheck(0), | ||||
| 		decryptBuf:     make([]byte, bufferSize), | ||||
| 		encryptBuf:     make([]byte, bufferSize), | ||||
| 		counter:        uint64(time.Now().Unix()) << 30, | ||||
| 		peerUpdates:    make(chan *m.Peer), | ||||
| 		controlPackets: make(chan controlPacket, 512), | ||||
| 	} | ||||
|  | ||||
| 	pd := peerData{} | ||||
| 	rp.shared.Store(&pd) | ||||
|  | ||||
| 	go newPeerSuper(rp).Run() | ||||
|  | ||||
| 	return rp | ||||
| } | ||||
|  | ||||
| func (rp *remotePeer) logf(msg string, args ...any) { | ||||
| 	log.Printf(fmt.Sprintf("[%03d] ", rp.remoteIP)+msg, args...) | ||||
| } | ||||
|  | ||||
| func (rp *remotePeer) HandlePeerUpdate(peer *m.Peer) { | ||||
| 	if peer != nil && peer.Version != rp.peerVersion { | ||||
| 		rp.peerUpdates <- peer | ||||
| 		rp.peerVersion = peer.Version | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| // HandlePacket accepts a raw data packet coming in from the network. | ||||
| // | ||||
| // This function is called by a single thread. | ||||
| func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h xHeader, data []byte) { | ||||
| 	switch h.StreamID { | ||||
| 	case controlStreamID: | ||||
| 		rp.handleControlPacket(addr, h, data) | ||||
|  | ||||
| 	case dataStreamID: | ||||
| 		rp.handleDataPacket(data) | ||||
|  | ||||
| 	case forwardStreamID: | ||||
| 		fallthrough | ||||
| 		// TODO | ||||
| 		//rp.handleForwardPacket(h, data) | ||||
| 	default: | ||||
| 		rp.logf("Unknown stream ID: %d", h.StreamID) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h xHeader, data []byte) { | ||||
| 	shared := rp.shared.Load() | ||||
| 	if shared.controlCipher == nil { | ||||
| 		rp.logf("Not connected (control).") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	out, ok := shared.controlCipher.Decrypt(data, rp.decryptBuf) | ||||
| 	if !ok { | ||||
| 		rp.logf("Failed to decrypt control packet.") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if len(out) == 0 { | ||||
| 		rp.logf("Empty control packet from: %d", h.SourceIP) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if rp.dupCheck.IsDup(h.Counter) { | ||||
| 		rp.logf("Duplicate control packet: %d", h.Counter) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if h.DestIP != rp.localIP { | ||||
| 		// TODO: Forward control packet. | ||||
| 		// TODO: Probably this should be dropped. | ||||
| 		// Control packets should be forwarded as data for efficiency. | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	pkt := controlPacket{ | ||||
| 		SrcIP:      h.SourceIP, | ||||
| 		RemoteAddr: addr, | ||||
| 	} | ||||
|  | ||||
| 	var err error | ||||
|  | ||||
| 	switch out[0] { | ||||
| 	case packetTypePing: | ||||
| 		pkt.Payload, err = parsePingPacket(out) | ||||
| 	case packetTypePong: | ||||
| 		pkt.Payload, err = parsePongPacket(out) | ||||
| 	default: | ||||
| 		rp.logf("Unknown control packet type: %d", out[0]) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if err != nil { | ||||
| 		rp.logf("Failed to parse control packet: %v", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	select { | ||||
| 	case rp.controlPackets <- pkt: | ||||
| 	default: | ||||
| 		rp.logf("Dropping control packet.") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (rp *remotePeer) handleDataPacket(data []byte) { | ||||
| 	shared := rp.shared.Load() | ||||
| 	if shared.dataCipher == nil { | ||||
| 		rp.logf("Not connected (recv).") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	dec, ok := shared.dataCipher.Decrypt(data, rp.decryptBuf) | ||||
| 	if !ok { | ||||
| 		rp.logf("Failed to decrypt data packet.") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	rp.iface.Write(dec) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| // SendData sends data coming from the interface going to the network. | ||||
| // | ||||
| // This function is called by a single thread. | ||||
| func (rp *remotePeer) SendData(data []byte) { | ||||
| 	shared := rp.shared.Load() | ||||
| 	if shared.dataCipher == nil || shared.remoteAddr == zeroAddrPort { | ||||
| 		rp.logf("Not connected (send).") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	h := xHeader{ | ||||
| 		StreamID: dataStreamID, | ||||
| 		Counter:  atomic.AddUint64(&rp.counter, 1), | ||||
| 		SourceIP: rp.localIP, | ||||
| 		DestIP:   rp.remoteIP, | ||||
| 	} | ||||
|  | ||||
| 	enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf) | ||||
| 	rp.conn.WriteTo(enc, shared.remoteAddr) | ||||
| } | ||||
|   | ||||
| @@ -8,12 +8,6 @@ import ( | ||||
| 	"vppn/m" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	connectTimeout  = 6 * time.Second | ||||
| 	pingInterval    = 6 * time.Second | ||||
| 	timeoutInterval = 20 * time.Second | ||||
| ) | ||||
|  | ||||
| type routingPacketWrapper struct { | ||||
| 	routingPacket | ||||
| 	Addr netip.AddrPort // Source. | ||||
| @@ -113,8 +107,6 @@ func (s *peerSupervisor) HandlePacket(w routingPacketWrapper) { | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type stateFunc func() stateFunc | ||||
|  | ||||
| func (s *peerSupervisor) stateInit() stateFunc { | ||||
| 	if s.peer == nil { | ||||
| 		return s.stateDisconnected | ||||
| @@ -316,12 +308,12 @@ func (s *peerSupervisor) updateRoutingTable(up bool) { | ||||
| func (s *peerSupervisor) sendPing() uint64 { | ||||
| 	traceID := newTraceID() | ||||
| 	pkt := newRoutingPacket(packetTypePing, traceID) | ||||
| 	s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf)) | ||||
| 	s.w.WriteTo(s.peer.PeerIP, streamControl, pkt.Marshal(s.buf)) | ||||
| 	s.pingTimer.Reset(pingInterval) | ||||
| 	return traceID | ||||
| } | ||||
|  | ||||
| func (s *peerSupervisor) sendPong(traceID uint64) { | ||||
| 	pkt := newRoutingPacket(packetTypePong, traceID) | ||||
| 	s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf)) | ||||
| 	s.w.WriteTo(s.peer.PeerIP, streamControl, pkt.Marshal(s.buf)) | ||||
| } | ||||
|   | ||||
| @@ -19,7 +19,7 @@ type peer struct { | ||||
| 	Up            bool           // No data will be sent to peers that are down. | ||||
| 	Addr          netip.AddrPort // If we have direct connection, otherwise use mediator. | ||||
| 	Mediator      bool           // True if the peer will mediate. | ||||
| 	RoutingCipher routingCipher | ||||
| 	RoutingCipher controlCipher | ||||
| 	DataCipher    dataCipher | ||||
|  | ||||
| 	// TODO: Deprecated below. | ||||
|   | ||||
		Reference in New Issue
	
	Block a user