refactor-for-testability #3
| @@ -2,11 +2,10 @@ | |||||||
|  |  | ||||||
| ## Refactoring for Testability | ## Refactoring for Testability | ||||||
|  |  | ||||||
| * [ ] connWriter | * [x] connWriter | ||||||
|   * [ ] Separate send/relay calls |  | ||||||
| * [x] mcWriter | * [x] mcWriter | ||||||
| * [x] ifWriter | * [x] ifWriter | ||||||
| * [ ] ifReader | * [ ] ifReader (testing) | ||||||
| * [ ] connReader | * [ ] connReader | ||||||
| * [ ] mcReader | * [ ] mcReader | ||||||
| * [ ] hubPoller | * [ ] hubPoller | ||||||
|   | |||||||
| @@ -68,18 +68,18 @@ func newConnWriter(conn udpAddrPortWriter, localIP byte) *connWriter { | |||||||
|  |  | ||||||
| // Not safe for concurrent use. Should only be called by supervisor. | // Not safe for concurrent use. Should only be called by supervisor. | ||||||
| func (w *connWriter) SendControlPacket(pkt marshaller, route *peerRoute) { | func (w *connWriter) SendControlPacket(pkt marshaller, route *peerRoute) { | ||||||
| 	buf := pkt.Marshal(w.cBuf1) | 	buf := w.encryptControlPacket(pkt, route) | ||||||
| 	h := header{ |  | ||||||
| 		StreamID: controlStreamID, |  | ||||||
| 		Counter:  atomic.AddUint64(&w.counters[route.IP], 1), |  | ||||||
| 		SourceIP: w.localIP, |  | ||||||
| 		DestIP:   route.IP, |  | ||||||
| 	} |  | ||||||
| 	buf = route.ControlCipher.Encrypt(h, buf, w.cBuf2) |  | ||||||
| 	w.writeTo(buf, route.RemoteAddr) | 	w.writeTo(buf, route.RemoteAddr) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Relay control packet. Routes must not be nil. | ||||||
| func (w *connWriter) RelayControlPacket(pkt marshaller, route, relay *peerRoute) { | func (w *connWriter) RelayControlPacket(pkt marshaller, route, relay *peerRoute) { | ||||||
|  | 	buf := w.encryptControlPacket(pkt, route) | ||||||
|  | 	w.relayPacket(buf, w.cBuf1, route, relay) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Encrypted packet will occupy cBuf2. | ||||||
|  | func (w *connWriter) encryptControlPacket(pkt marshaller, route *peerRoute) []byte { | ||||||
| 	buf := pkt.Marshal(w.cBuf1) | 	buf := pkt.Marshal(w.cBuf1) | ||||||
| 	h := header{ | 	h := header{ | ||||||
| 		StreamID: controlStreamID, | 		StreamID: controlStreamID, | ||||||
| @@ -87,12 +87,11 @@ func (w *connWriter) RelayControlPacket(pkt marshaller, route, relay *peerRoute) | |||||||
| 		SourceIP: w.localIP, | 		SourceIP: w.localIP, | ||||||
| 		DestIP:   route.IP, | 		DestIP:   route.IP, | ||||||
| 	} | 	} | ||||||
| 	buf = route.ControlCipher.Encrypt(h, buf, w.cBuf2) | 	return route.ControlCipher.Encrypt(h, buf, w.cBuf2) | ||||||
| 	w.relayPacket(buf, w.cBuf1, route, relay) |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // Not safe for concurrent use. Should only be called by ifReader. | // Not safe for concurrent use. Should only be called by ifReader. | ||||||
| func (w *connWriter) SendDataPacket(pkt []byte, route, relay *peerRoute) { | func (w *connWriter) SendDataPacket(pkt []byte, route *peerRoute) { | ||||||
| 	h := header{ | 	h := header{ | ||||||
| 		StreamID: dataStreamID, | 		StreamID: dataStreamID, | ||||||
| 		Counter:  atomic.AddUint64(&w.counters[route.IP], 1), | 		Counter:  atomic.AddUint64(&w.counters[route.IP], 1), | ||||||
| @@ -101,17 +100,22 @@ func (w *connWriter) SendDataPacket(pkt []byte, route, relay *peerRoute) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	enc := route.DataCipher.Encrypt(h, pkt, w.dBuf1) | 	enc := route.DataCipher.Encrypt(h, pkt, w.dBuf1) | ||||||
|  |  | ||||||
| 	if route.Direct { |  | ||||||
| 	w.writeTo(enc, route.RemoteAddr) | 	w.writeTo(enc, route.RemoteAddr) | ||||||
| 		return |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Relay a data packet. Routes must not be nil. | ||||||
|  | func (w *connWriter) RelayDataPacket(pkt []byte, route, relay *peerRoute) { | ||||||
|  | 	h := header{ | ||||||
|  | 		StreamID: dataStreamID, | ||||||
|  | 		Counter:  atomic.AddUint64(&w.counters[route.IP], 1), | ||||||
|  | 		SourceIP: w.localIP, | ||||||
|  | 		DestIP:   route.IP, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	enc := route.DataCipher.Encrypt(h, pkt, w.dBuf1) | ||||||
| 	w.relayPacket(enc, w.dBuf2, route, relay) | 	w.relayPacket(enc, w.dBuf2, route, relay) | ||||||
| } | } | ||||||
|  |  | ||||||
| // TODO: RelayDataPacket |  | ||||||
|  |  | ||||||
| // Safe for concurrent use. Should only be called by connReader. | // Safe for concurrent use. Should only be called by connReader. | ||||||
| // | // | ||||||
| // This function will send pkt to the peer directly. This is used when a peer | // This function will send pkt to the peer directly. This is used when a peer | ||||||
| @@ -122,10 +126,6 @@ func (w *connWriter) SendEncryptedDataPacket(pkt []byte, route *peerRoute) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (w *connWriter) relayPacket(data, buf []byte, route, relay *peerRoute) { | func (w *connWriter) relayPacket(data, buf []byte, route, relay *peerRoute) { | ||||||
| 	if relay == nil || !relay.Up { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	h := header{ | 	h := header{ | ||||||
| 		StreamID: dataStreamID, | 		StreamID: dataStreamID, | ||||||
| 		Counter:  atomic.AddUint64(&w.counters[relay.IP], 1), | 		Counter:  atomic.AddUint64(&w.counters[relay.IP], 1), | ||||||
|   | |||||||
| @@ -126,7 +126,7 @@ func TestConnWriter_SendControlPacket_direct(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Testing if we can relay a packet via an intermediary. | // Testing if we can relay a packet via an intermediary. | ||||||
| func TestConnWriter_SendControlPacket_relay(t *testing.T) { | func TestConnWriter_RelayControlPacket_relay(t *testing.T) { | ||||||
| 	route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() | 	route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() | ||||||
|  |  | ||||||
| 	writer := &testUDPAddrPortWriter{} | 	writer := &testUDPAddrPortWriter{} | ||||||
| @@ -159,40 +159,6 @@ func TestConnWriter_SendControlPacket_relay(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // Testing that a nil relay doesn't cause an issue. |  | ||||||
| func TestConnWriter_SendControlPacket_relay_relayNil(t *testing.T) { |  | ||||||
| 	route, rRoute, _, _ := testConnWriter_getTestRoutes() |  | ||||||
|  |  | ||||||
| 	writer := &testUDPAddrPortWriter{} |  | ||||||
| 	w := newConnWriter(writer, rRoute.IP) |  | ||||||
| 	in := testPacket("hello world!") |  | ||||||
|  |  | ||||||
| 	w.RelayControlPacket(in, route, nil) |  | ||||||
|  |  | ||||||
| 	out := writer.Written() |  | ||||||
| 	if len(out) != 0 { |  | ||||||
| 		t.Fatal(out) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that we don't send anything if the relay isn't up. |  | ||||||
| func TestConnWriter_SendControlPacket_relay_relayNotUp(t *testing.T) { |  | ||||||
| 	route, rRoute, relay, _ := testConnWriter_getTestRoutes() |  | ||||||
| 	relay.Up = false |  | ||||||
|  |  | ||||||
| 	writer := &testUDPAddrPortWriter{} |  | ||||||
| 	w := newConnWriter(writer, rRoute.IP) |  | ||||||
| 	in := testPacket("hello world!") |  | ||||||
|  |  | ||||||
| 	w.RelayControlPacket(in, route, relay) |  | ||||||
|  |  | ||||||
| 	out := writer.Written() |  | ||||||
| 	if len(out) != 0 { |  | ||||||
| 		t.Fatal(out) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Testing that we can send a data packet directly to a remote route. | // Testing that we can send a data packet directly to a remote route. | ||||||
| func TestConnWriter_SendDataPacket_direct(t *testing.T) { | func TestConnWriter_SendDataPacket_direct(t *testing.T) { | ||||||
| 	route, rRoute, _, _ := testConnWriter_getTestRoutes() | 	route, rRoute, _, _ := testConnWriter_getTestRoutes() | ||||||
| @@ -202,7 +168,7 @@ func TestConnWriter_SendDataPacket_direct(t *testing.T) { | |||||||
| 	w := newConnWriter(writer, rRoute.IP) | 	w := newConnWriter(writer, rRoute.IP) | ||||||
|  |  | ||||||
| 	in := []byte("hello world!") | 	in := []byte("hello world!") | ||||||
| 	w.SendDataPacket(in, route, nil) | 	w.SendDataPacket(in, route) | ||||||
|  |  | ||||||
| 	out := writer.Written() | 	out := writer.Written() | ||||||
| 	if len(out) != 1 { | 	if len(out) != 1 { | ||||||
| @@ -224,14 +190,14 @@ func TestConnWriter_SendDataPacket_direct(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Testing that we can relay a data packet via a relay. | // Testing that we can relay a data packet via a relay. | ||||||
| func TestConnWriter_SendDataPacket_relay(t *testing.T) { | func TestConnWriter_RelayDataPacket_relay(t *testing.T) { | ||||||
| 	route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() | 	route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() | ||||||
|  |  | ||||||
| 	writer := &testUDPAddrPortWriter{} | 	writer := &testUDPAddrPortWriter{} | ||||||
| 	w := newConnWriter(writer, rRoute.IP) | 	w := newConnWriter(writer, rRoute.IP) | ||||||
| 	in := []byte("Hello world!") | 	in := []byte("Hello world!") | ||||||
|  |  | ||||||
| 	w.SendDataPacket(in, route, relay) | 	w.RelayDataPacket(in, route, relay) | ||||||
|  |  | ||||||
| 	out := writer.Written() | 	out := writer.Written() | ||||||
| 	if len(out) != 1 { | 	if len(out) != 1 { | ||||||
| @@ -257,35 +223,26 @@ func TestConnWriter_SendDataPacket_relay(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // Testing that we don't attempt to relay if the relay is nil. | // Testing that we can send an already encrypted packet. | ||||||
| func TestConnWriter_SendDataPacket_relay_relayNil(t *testing.T) { | func TestConnWriter_SendEncryptedDataPacket(t *testing.T) { | ||||||
| 	route, rRoute, _, _ := testConnWriter_getTestRoutes() | 	route, rRoute, _, _ := testConnWriter_getTestRoutes() | ||||||
|  |  | ||||||
| 	writer := &testUDPAddrPortWriter{} | 	writer := &testUDPAddrPortWriter{} | ||||||
| 	w := newConnWriter(writer, rRoute.IP) | 	w := newConnWriter(writer, rRoute.IP) | ||||||
| 	in := []byte("Hello world!") | 	in := []byte("Hello world!") | ||||||
|  |  | ||||||
| 	w.SendDataPacket(in, route, nil) | 	w.SendEncryptedDataPacket(in, route) | ||||||
|  |  | ||||||
| 	out := writer.Written() | 	out := writer.Written() | ||||||
| 	if len(out) != 0 { | 	if len(out) != 1 { | ||||||
| 		t.Fatal(out) | 		t.Fatal(out) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if out[0].Addr != route.RemoteAddr { | ||||||
|  | 		t.Fatal(out[0]) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| // Testing that we don't attempt to relay if the relay isn't up. | 	if !bytes.Equal(out[0].Data, in) { | ||||||
| func TestConnWriter_SendDataPacket_relay_relayNotUp(t *testing.T) { | 		t.Fatal(out[0]) | ||||||
| 	route, rRoute, relay, _ := testConnWriter_getTestRoutes() |  | ||||||
| 	relay.Up = false |  | ||||||
|  |  | ||||||
| 	writer := &testUDPAddrPortWriter{} |  | ||||||
| 	w := newConnWriter(writer, rRoute.IP) |  | ||||||
| 	in := []byte("Hello world!") |  | ||||||
|  |  | ||||||
| 	w.SendDataPacket(in, route, relay) |  | ||||||
|  |  | ||||||
| 	out := writer.Written() |  | ||||||
| 	if len(out) != 0 { |  | ||||||
| 		t.Fatal(out) |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -38,14 +38,14 @@ func (dc *dupCheck) IsDup(counter uint64) bool { | |||||||
| 	delta := counter - dc.tailCounter | 	delta := counter - dc.tailCounter | ||||||
|  |  | ||||||
| 	// Full clear. | 	// Full clear. | ||||||
| 	if delta >= bitSetSize { | 	if delta >= bitSetSize-1 { | ||||||
| 		dc.ClearAll() | 		dc.ClearAll() | ||||||
| 		dc.Set(0) | 		dc.Set(0) | ||||||
|  |  | ||||||
| 		dc.tail = 1 | 		dc.tail = 1 | ||||||
| 		dc.head = 2 | 		dc.head = 2 | ||||||
| 		dc.tailCounter = counter + 1 | 		dc.tailCounter = counter + 1 | ||||||
| 		dc.headCounter = dc.tailCounter - bitSetSize | 		dc.headCounter = dc.tailCounter - bitSetSize + 1 | ||||||
|  |  | ||||||
| 		return false | 		return false | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -20,6 +20,18 @@ type header struct { | |||||||
| 	Counter  uint64 // Init with time.Now().Unix << 30 to ensure monotonic. | 	Counter  uint64 // Init with time.Now().Unix << 30 to ensure monotonic. | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func parseHeader(b []byte) (h header, ok bool) { | ||||||
|  | 	if len(b) < headerSize { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	h.Version = b[0] | ||||||
|  | 	h.StreamID = b[1] | ||||||
|  | 	h.SourceIP = b[2] | ||||||
|  | 	h.DestIP = b[3] | ||||||
|  | 	h.Counter = *(*uint64)(unsafe.Pointer(&b[4])) | ||||||
|  | 	return h, true | ||||||
|  | } | ||||||
|  |  | ||||||
| func (h *header) Parse(b []byte) { | func (h *header) Parse(b []byte) { | ||||||
| 	h.Version = b[0] | 	h.Version = b[0] | ||||||
| 	h.StreamID = b[1] | 	h.StreamID = b[1] | ||||||
|   | |||||||
| @@ -57,7 +57,7 @@ func (r *ifReader) sendPacket(pkt []byte, remoteIP byte) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if relay := r.relay.Load(); relay != nil { | 	if relay := r.relay.Load(); relay != nil && relay.Up { | ||||||
| 		r.relayDataPacket(pkt, route, relay) | 		r.relayDataPacket(pkt, route, relay) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										21
									
								
								peer/bitset.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								peer/bitset.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | const bitSetSize = 512 // Multiple of 64. | ||||||
|  |  | ||||||
|  | type bitSet [bitSetSize / 64]uint64 | ||||||
|  |  | ||||||
|  | func (bs *bitSet) Set(i int) { | ||||||
|  | 	bs[i/64] |= 1 << (i % 64) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (bs *bitSet) Clear(i int) { | ||||||
|  | 	bs[i/64] &= ^(1 << (i % 64)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (bs *bitSet) ClearAll() { | ||||||
|  | 	clear(bs[:]) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (bs *bitSet) Get(i int) bool { | ||||||
|  | 	return bs[i/64]&(1<<(i%64)) != 0 | ||||||
|  | } | ||||||
							
								
								
									
										48
									
								
								peer/bitset_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								peer/bitset_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,48 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"math/rand" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestBitSet(t *testing.T) { | ||||||
|  | 	state := make([]bool, bitSetSize) | ||||||
|  | 	for i := range state { | ||||||
|  | 		state[i] = rand.Float32() > 0.5 | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	bs := bitSet{} | ||||||
|  |  | ||||||
|  | 	for i := range state { | ||||||
|  | 		if state[i] { | ||||||
|  | 			bs.Set(i) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for i := range state { | ||||||
|  | 		if bs.Get(i) != state[i] { | ||||||
|  | 			t.Fatal(i, state[i], bs.Get(i)) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for i := range state { | ||||||
|  | 		if rand.Float32() > 0.5 { | ||||||
|  | 			state[i] = false | ||||||
|  | 			bs.Clear(i) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for i := range state { | ||||||
|  | 		if bs.Get(i) != state[i] { | ||||||
|  | 			t.Fatal(i, state[i], bs.Get(i)) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	bs.ClearAll() | ||||||
|  |  | ||||||
|  | 	for i := range state { | ||||||
|  | 		if bs.Get(i) { | ||||||
|  | 			t.Fatal(i, bs.Get(i)) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										26
									
								
								peer/cipher-control.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								peer/cipher-control.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import "golang.org/x/crypto/nacl/box" | ||||||
|  |  | ||||||
|  | type controlCipher struct { | ||||||
|  | 	sharedKey [32]byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newControlCipher(privKey, pubKey []byte) *controlCipher { | ||||||
|  | 	shared := [32]byte{} | ||||||
|  | 	box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey)) | ||||||
|  | 	return &controlCipher{shared} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (cc *controlCipher) Encrypt(h header, data, out []byte) []byte { | ||||||
|  | 	const s = controlHeaderSize | ||||||
|  | 	out = out[:s+controlCipherOverhead+len(data)] | ||||||
|  | 	h.Marshal(out[:s]) | ||||||
|  | 	box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &cc.sharedKey) | ||||||
|  | 	return out | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (cc *controlCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) { | ||||||
|  | 	const s = controlHeaderSize | ||||||
|  | 	return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &cc.sharedKey) | ||||||
|  | } | ||||||
							
								
								
									
										122
									
								
								peer/cipher-control_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										122
									
								
								peer/cipher-control_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,122 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/crypto/nacl/box" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func newControlCipherForTesting() (c1, c2 *controlCipher) { | ||||||
|  | 	pubKey1, privKey1, err := box.GenerateKey(rand.Reader) | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	pubKey2, privKey2, err := box.GenerateKey(rand.Reader) | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return newControlCipher(privKey1[:], pubKey2[:]), | ||||||
|  | 		newControlCipher(privKey2[:], pubKey1[:]) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestControlCipher(t *testing.T) { | ||||||
|  | 	c1, c2 := newControlCipherForTesting() | ||||||
|  |  | ||||||
|  | 	maxSizePlaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) | ||||||
|  | 	rand.Read(maxSizePlaintext) | ||||||
|  |  | ||||||
|  | 	testCases := [][]byte{ | ||||||
|  | 		make([]byte, 0), | ||||||
|  | 		{1}, | ||||||
|  | 		{255}, | ||||||
|  | 		{1, 2, 3, 4, 5}, | ||||||
|  | 		[]byte("Hello world"), | ||||||
|  | 		maxSizePlaintext, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, plaintext := range testCases { | ||||||
|  | 		h1 := header{ | ||||||
|  | 			StreamID: controlStreamID, | ||||||
|  | 			Counter:  235153, | ||||||
|  | 			SourceIP: 4, | ||||||
|  | 			DestIP:   88, | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		encrypted := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 		encrypted = c1.Encrypt(h1, plaintext, encrypted) | ||||||
|  |  | ||||||
|  | 		h2 := header{} | ||||||
|  | 		h2.Parse(encrypted) | ||||||
|  | 		if !reflect.DeepEqual(h1, h2) { | ||||||
|  | 			t.Fatal(h1, h2) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		decrypted, ok := c2.Decrypt(encrypted, make([]byte, bufferSize)) | ||||||
|  | 		if !ok { | ||||||
|  | 			t.Fatal(ok) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if !bytes.Equal(decrypted, plaintext) { | ||||||
|  | 			t.Fatal("not equal") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestControlCipher_ShortCiphertext(t *testing.T) { | ||||||
|  | 	c1, _ := newControlCipherForTesting() | ||||||
|  | 	shortText := make([]byte, controlHeaderSize+controlCipherOverhead-1) | ||||||
|  | 	rand.Read(shortText) | ||||||
|  | 	_, ok := c1.Decrypt(shortText, make([]byte, bufferSize)) | ||||||
|  | 	if ok { | ||||||
|  | 		t.Fatal(ok) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func BenchmarkControlCipher_Encrypt(b *testing.B) { | ||||||
|  | 	c1, _ := newControlCipherForTesting() | ||||||
|  | 	h1 := header{ | ||||||
|  | 		Counter:  235153, | ||||||
|  | 		SourceIP: 4, | ||||||
|  | 		DestIP:   88, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) | ||||||
|  | 	rand.Read(plaintext) | ||||||
|  |  | ||||||
|  | 	encrypted := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 	b.ResetTimer() | ||||||
|  | 	for i := 0; i < b.N; i++ { | ||||||
|  | 		encrypted = c1.Encrypt(h1, plaintext, encrypted) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func BenchmarkControlCipher_Decrypt(b *testing.B) { | ||||||
|  | 	c1, c2 := newControlCipherForTesting() | ||||||
|  |  | ||||||
|  | 	h1 := header{ | ||||||
|  | 		Counter:  235153, | ||||||
|  | 		SourceIP: 4, | ||||||
|  | 		DestIP:   88, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead) | ||||||
|  | 	rand.Read(plaintext) | ||||||
|  |  | ||||||
|  | 	encrypted := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 	encrypted = c1.Encrypt(h1, plaintext, encrypted) | ||||||
|  |  | ||||||
|  | 	decrypted := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 	b.ResetTimer() | ||||||
|  | 	for i := 0; i < b.N; i++ { | ||||||
|  | 		decrypted, _ = c2.Decrypt(encrypted, decrypted) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										61
									
								
								peer/cipher-data.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								peer/cipher-data.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,61 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"crypto/aes" | ||||||
|  | 	"crypto/cipher" | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"log" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type dataCipher struct { | ||||||
|  | 	key  [32]byte | ||||||
|  | 	aead cipher.AEAD | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newDataCipher() *dataCipher { | ||||||
|  | 	key := [32]byte{} | ||||||
|  | 	if _, err := rand.Read(key[:]); err != nil { | ||||||
|  | 		log.Fatalf("Failed to read random data: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return newDataCipherFromKey(key) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newDataCipherFromKey(key [32]byte) *dataCipher { | ||||||
|  | 	block, err := aes.NewCipher(key[:]) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatalf("Failed to create new cipher: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	aead, err := cipher.NewGCM(block) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatalf("Failed to create new GCM: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &dataCipher{key: key, aead: aead} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (sc *dataCipher) Key() [32]byte { | ||||||
|  | 	return sc.key | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (sc *dataCipher) Encrypt(h header, data, out []byte) []byte { | ||||||
|  | 	const s = dataHeaderSize | ||||||
|  | 	out = out[:s+dataCipherOverhead+len(data)] | ||||||
|  | 	h.Marshal(out[:s]) | ||||||
|  | 	sc.aead.Seal(out[s:s], out[:s], data, nil) | ||||||
|  | 	return out | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) { | ||||||
|  | 	const s = dataHeaderSize | ||||||
|  | 	if len(encrypted) < s+dataCipherOverhead { | ||||||
|  | 		ok = false | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var err error | ||||||
|  |  | ||||||
|  | 	data, err = sc.aead.Open(out[:0], encrypted[:s], encrypted[s:], nil) | ||||||
|  | 	ok = err == nil | ||||||
|  | 	return | ||||||
|  | } | ||||||
							
								
								
									
										141
									
								
								peer/cipher-data_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										141
									
								
								peer/cipher-data_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,141 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	mrand "math/rand/v2" | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestDataCipher(t *testing.T) { | ||||||
|  | 	maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) | ||||||
|  | 	rand.Read(maxSizePlaintext) | ||||||
|  |  | ||||||
|  | 	testCases := [][]byte{ | ||||||
|  | 		make([]byte, 0), | ||||||
|  | 		{1}, | ||||||
|  | 		{255}, | ||||||
|  | 		{1, 2, 3, 4, 5}, | ||||||
|  | 		[]byte("Hello world"), | ||||||
|  | 		maxSizePlaintext, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, plaintext := range testCases { | ||||||
|  | 		h1 := header{ | ||||||
|  | 			StreamID: dataStreamID, | ||||||
|  | 			Counter:  235153, | ||||||
|  | 			SourceIP: 4, | ||||||
|  | 			DestIP:   88, | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		encrypted := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 		dc1 := newDataCipher() | ||||||
|  | 		encrypted = dc1.Encrypt(h1, plaintext, encrypted) | ||||||
|  | 		h2 := header{} | ||||||
|  | 		h2.Parse(encrypted) | ||||||
|  |  | ||||||
|  | 		dc2 := newDataCipherFromKey(dc1.Key()) | ||||||
|  |  | ||||||
|  | 		decrypted, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) | ||||||
|  | 		if !ok { | ||||||
|  | 			t.Fatal(ok) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if !bytes.Equal(plaintext, decrypted) { | ||||||
|  | 			t.Fatal("not equal") | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if !reflect.DeepEqual(h1, h2) { | ||||||
|  | 			t.Fatalf("%v != %v", h1, h2) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestDataCipher_ModifyCiphertext(t *testing.T) { | ||||||
|  | 	maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) | ||||||
|  | 	rand.Read(maxSizePlaintext) | ||||||
|  |  | ||||||
|  | 	testCases := [][]byte{ | ||||||
|  | 		make([]byte, 0), | ||||||
|  | 		{1}, | ||||||
|  | 		{255}, | ||||||
|  | 		{1, 2, 3, 4, 5}, | ||||||
|  | 		[]byte("Hello world"), | ||||||
|  | 		maxSizePlaintext, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, plaintext := range testCases { | ||||||
|  | 		h1 := header{ | ||||||
|  | 			Counter:  235153, | ||||||
|  | 			SourceIP: 4, | ||||||
|  | 			DestIP:   88, | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		encrypted := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 		dc1 := newDataCipher() | ||||||
|  | 		encrypted = dc1.Encrypt(h1, plaintext, encrypted) | ||||||
|  | 		encrypted[mrand.IntN(len(encrypted))]++ | ||||||
|  |  | ||||||
|  | 		dc2 := newDataCipherFromKey(dc1.Key()) | ||||||
|  |  | ||||||
|  | 		_, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize)) | ||||||
|  | 		if ok { | ||||||
|  | 			t.Fatal(ok) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestDataCipher_ShortCiphertext(t *testing.T) { | ||||||
|  | 	dc1 := newDataCipher() | ||||||
|  | 	shortText := make([]byte, dataHeaderSize+dataCipherOverhead-1) | ||||||
|  | 	rand.Read(shortText) | ||||||
|  | 	_, ok := dc1.Decrypt(shortText, make([]byte, bufferSize)) | ||||||
|  | 	if ok { | ||||||
|  | 		t.Fatal(ok) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func BenchmarkDataCipher_Encrypt(b *testing.B) { | ||||||
|  | 	h1 := header{ | ||||||
|  | 		Counter:  235153, | ||||||
|  | 		SourceIP: 4, | ||||||
|  | 		DestIP:   88, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) | ||||||
|  | 	rand.Read(plaintext) | ||||||
|  |  | ||||||
|  | 	encrypted := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 	dc1 := newDataCipher() | ||||||
|  | 	b.ResetTimer() | ||||||
|  | 	for i := 0; i < b.N; i++ { | ||||||
|  | 		encrypted = dc1.Encrypt(h1, plaintext, encrypted) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func BenchmarkDataCipher_Decrypt(b *testing.B) { | ||||||
|  | 	h1 := header{ | ||||||
|  | 		Counter:  235153, | ||||||
|  | 		SourceIP: 4, | ||||||
|  | 		DestIP:   88, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead) | ||||||
|  | 	rand.Read(plaintext) | ||||||
|  |  | ||||||
|  | 	encrypted := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 	dc1 := newDataCipher() | ||||||
|  | 	encrypted = dc1.Encrypt(h1, plaintext, encrypted) | ||||||
|  |  | ||||||
|  | 	decrypted := make([]byte, bufferSize) | ||||||
|  |  | ||||||
|  | 	b.ResetTimer() | ||||||
|  | 	for i := 0; i < b.N; i++ { | ||||||
|  | 		decrypted, _ = dc1.Decrypt(encrypted, decrypted) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										13
									
								
								peer/cipher-discovery.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								peer/cipher-discovery.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | /* | ||||||
|  | func signData(privKey *[64]byte, h header, data, out []byte) []byte { | ||||||
|  | 	out = out[:headerSize] | ||||||
|  | 	h.Marshal(out) | ||||||
|  | 	return sign.Sign(out, data, privKey) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func openData(pubKey *[32]byte, signed, out []byte) (data []byte, ok bool) { | ||||||
|  | 	return sign.Open(out[:0], signed[headerSize:], pubKey) | ||||||
|  | } | ||||||
|  | */ | ||||||
							
								
								
									
										141
									
								
								peer/connreader.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										141
									
								
								peer/connreader.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,141 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"log" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"sync/atomic" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type connReader struct { | ||||||
|  | 	conn    udpReader | ||||||
|  | 	iface   ifWriter | ||||||
|  | 	sender  encryptedPacketSender | ||||||
|  | 	super   controlMsgHandler | ||||||
|  | 	localIP byte | ||||||
|  | 	routes  [256]*atomic.Pointer[peerRoute] | ||||||
|  |  | ||||||
|  | 	buf       []byte | ||||||
|  | 	decBuf    []byte | ||||||
|  | 	dupChecks [256]*dupCheck | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newConnReader( | ||||||
|  | 	conn udpReader, | ||||||
|  | 	ifWriter ifWriter, | ||||||
|  | 	sender encryptedPacketSender, | ||||||
|  | 	super controlMsgHandler, | ||||||
|  | 	localIP byte, | ||||||
|  | 	routes [256]*atomic.Pointer[peerRoute], | ||||||
|  | ) *connReader { | ||||||
|  | 	return &connReader{ | ||||||
|  | 		conn:    conn, | ||||||
|  | 		iface:   ifWriter, | ||||||
|  | 		sender:  sender, | ||||||
|  | 		super:   super, | ||||||
|  | 		localIP: localIP, | ||||||
|  | 		routes:  routes, | ||||||
|  | 		buf:     make([]byte, bufferSize), | ||||||
|  | 		decBuf:  make([]byte, bufferSize), | ||||||
|  | 		dupChecks: func() (out [256]*dupCheck) { | ||||||
|  | 			for i := range out { | ||||||
|  | 				out[i] = newDupCheck(0) | ||||||
|  | 			} | ||||||
|  | 			return | ||||||
|  | 		}(), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *connReader) Run() { | ||||||
|  | 	for { | ||||||
|  | 		r.handleNextPacket() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *connReader) logf(s string, args ...any) { | ||||||
|  | 	log.Printf("[ConnReader] "+s, args...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *connReader) handleNextPacket() { | ||||||
|  | 	buf := r.buf[:bufferSize] | ||||||
|  | 	n, remoteAddr, err := r.conn.ReadFromUDPAddrPort(buf) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatalf("Failed to read from UDP port: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if n < headerSize { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port()) | ||||||
|  |  | ||||||
|  | 	buf = buf[:n] | ||||||
|  | 	h, ok := parseHeader(buf) | ||||||
|  | 	if !ok { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	route := r.routes[h.SourceIP].Load() | ||||||
|  |  | ||||||
|  | 	switch h.StreamID { | ||||||
|  | 	case controlStreamID: | ||||||
|  | 		r.handleControlPacket(route, remoteAddr, h, buf) | ||||||
|  |  | ||||||
|  | 	case dataStreamID: | ||||||
|  | 		r.handleDataPacket(route, h, buf) | ||||||
|  |  | ||||||
|  | 	default: | ||||||
|  | 		r.logf("Unknown stream ID: %d", h.StreamID) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *connReader) handleControlPacket( | ||||||
|  | 	route *peerRoute, | ||||||
|  | 	addr netip.AddrPort, | ||||||
|  | 	h header, | ||||||
|  | 	enc []byte, | ||||||
|  | ) { | ||||||
|  | 	if route.ControlCipher == nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if h.DestIP != r.localIP { | ||||||
|  | 		r.logf("Incorrect destination IP on control packet: %d", h.DestIP) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	msg, err := decryptControlPacket(route, addr, h, enc, r.decBuf) | ||||||
|  | 	if err != nil { | ||||||
|  | 		r.logf("Failed to decrypt control packet: %v", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	r.super.HandleControlMsg(msg) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *connReader) handleDataPacket(route *peerRoute, h header, enc []byte) { | ||||||
|  | 	if !route.Up { | ||||||
|  | 		r.logf("Not connected (recv).") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	data, err := decryptDataPacket(route, h, enc, r.decBuf) | ||||||
|  | 	if err != nil { | ||||||
|  | 		r.logf("Failed to decrypt data packet: %v", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if h.DestIP == r.localIP { | ||||||
|  | 		if _, err := r.iface.Write(data); err != nil { | ||||||
|  | 			log.Fatalf("Failed to write to interface: %v", err) | ||||||
|  | 		} | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	destRoute := r.routes[h.DestIP].Load() | ||||||
|  | 	if !destRoute.Up { | ||||||
|  | 		r.logf("Not connected (relay): %d", destRoute.IP) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	r.sender.SendEncryptedDataPacket(data, destRoute) | ||||||
|  | } | ||||||
							
								
								
									
										318
									
								
								peer/connreader_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										318
									
								
								peer/connreader_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,318 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | /* | ||||||
|  | type mockIfWriter struct { | ||||||
|  | 	Written [][]byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *mockIfWriter) Write(b []byte) (int, error) { | ||||||
|  | 	w.Written = append(w.Written, bytes.Clone(b)) | ||||||
|  | 	return len(b), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type mockEncryptedPacket struct { | ||||||
|  | 	Packet []byte | ||||||
|  | 	Route  *peerRoute | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type mockEncryptedPacketSender struct { | ||||||
|  | 	Sent []mockEncryptedPacket | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *mockEncryptedPacketSender) SendEncryptedDataPacket(pkt []byte, route *peerRoute) { | ||||||
|  | 	m.Sent = append(m.Sent, mockEncryptedPacket{ | ||||||
|  | 		Packet: bytes.Clone(pkt), | ||||||
|  | 		Route:  route, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type mockControlMsgHandler struct { | ||||||
|  | 	Messages []any | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *mockControlMsgHandler) HandleControlMsg(pkt any) { | ||||||
|  | 	m.Messages = append(m.Messages, pkt) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type udpPipe struct { | ||||||
|  | 	packets chan []byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newUDPPipe() *udpPipe { | ||||||
|  | 	return &udpPipe{make(chan []byte, 1024)} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p *udpPipe) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { | ||||||
|  | 	p.packets <- bytes.Clone(b) | ||||||
|  | 	return len(b), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p *udpPipe) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { | ||||||
|  | 	packet := <-p.packets | ||||||
|  | 	copy(b, packet) | ||||||
|  | 	return len(packet), netip.AddrPort{}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type connReaderTestHarness struct { | ||||||
|  | 	Pipe         *udpPipe | ||||||
|  | 	R            *connReader | ||||||
|  | 	WRemote      *connWriter | ||||||
|  | 	WRelayRemote *connWriter | ||||||
|  | 	Remote       *peerRoute | ||||||
|  | 	RelayRemote  *peerRoute | ||||||
|  | 	IFace        *mockIfWriter | ||||||
|  | 	Sender       *mockEncryptedPacketSender | ||||||
|  | 	Super        *mockControlMsgHandler | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Peer 2 is indirect, peer 3 is direct. | ||||||
|  | func newConnReadeTestHarness() (h connReaderTestHarness) { | ||||||
|  | 	pipe := newUDPPipe() | ||||||
|  | 	routes := [256]*atomic.Pointer[peerRoute]{} | ||||||
|  | 	for i := range routes { | ||||||
|  | 		routes[i] = &atomic.Pointer[peerRoute]{} | ||||||
|  | 		routes[i].Store(&peerRoute{}) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	local, remote, relayLocal, relayRemote := testConnWriter_getTestRoutes() | ||||||
|  | 	routes[2].Store(local) | ||||||
|  | 	routes[3].Store(relayLocal) | ||||||
|  |  | ||||||
|  | 	h.Pipe = pipe | ||||||
|  | 	h.WRemote = newConnWriter(pipe, 2) | ||||||
|  | 	h.WRelayRemote = newConnWriter(pipe, 3) | ||||||
|  |  | ||||||
|  | 	h.Remote = remote | ||||||
|  | 	h.RelayRemote = relayRemote | ||||||
|  | 	h.IFace = &mockIfWriter{} | ||||||
|  | 	h.Sender = &mockEncryptedPacketSender{} | ||||||
|  | 	h.Super = &mockControlMsgHandler{} | ||||||
|  | 	h.R = newConnReader( | ||||||
|  | 		pipe, | ||||||
|  | 		h.IFace, | ||||||
|  | 		h.Sender, | ||||||
|  | 		h.Super, | ||||||
|  | 		1, | ||||||
|  | 		routes) | ||||||
|  | 	return h | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that we can receive a control packet. | ||||||
|  | func TestConnReader_handleControlPacket(t *testing.T) { | ||||||
|  | 	h := newConnReadeTestHarness() | ||||||
|  |  | ||||||
|  | 	pkt := synPacket{TraceID: 1234} | ||||||
|  |  | ||||||
|  | 	h.WRemote.SendControlPacket(pkt, h.Remote) | ||||||
|  |  | ||||||
|  | 	h.R.handleNextPacket() | ||||||
|  |  | ||||||
|  | 	if len(h.Super.Messages) != 1 { | ||||||
|  | 		t.Fatal(h.Super.Messages) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	msg := h.Super.Messages[0].(controlMsg[synPacket]) | ||||||
|  | 	if !reflect.DeepEqual(pkt, msg.Packet) { | ||||||
|  | 		t.Fatal(msg.Packet) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that a short packet is ignored. | ||||||
|  | func TestConnReader_handleNextPacket_short(t *testing.T) { | ||||||
|  | 	h := newConnReadeTestHarness() | ||||||
|  |  | ||||||
|  | 	h.Pipe.WriteToUDPAddrPort([]byte{1, 2, 3}, netip.AddrPort{}) | ||||||
|  | 	h.R.handleNextPacket() | ||||||
|  |  | ||||||
|  | 	if len(h.Super.Messages) != 0 { | ||||||
|  | 		t.Fatal(h.Super.Messages) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that a packet with an unexpected stream ID is ignored. | ||||||
|  | func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) { | ||||||
|  | 	h := newConnReadeTestHarness() | ||||||
|  |  | ||||||
|  | 	pkt := synPacket{TraceID: 1234} | ||||||
|  |  | ||||||
|  | 	encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) | ||||||
|  | 	var header header | ||||||
|  | 	header.Parse(encrypted) | ||||||
|  | 	header.StreamID = 100 | ||||||
|  | 	header.Marshal(encrypted) | ||||||
|  |  | ||||||
|  | 	h.WRemote.writeTo(encrypted, netip.AddrPort{}) | ||||||
|  | 	h.R.handleNextPacket() | ||||||
|  | 	if len(h.Super.Messages) != 0 { | ||||||
|  | 		t.Fatal(h.Super.Messages) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that control packet without matching control cipher is ignored. | ||||||
|  | func TestConnReader_handleControlPacket_noCipher(t *testing.T) { | ||||||
|  | 	h := newConnReadeTestHarness() | ||||||
|  |  | ||||||
|  | 	pkt := synPacket{TraceID: 1234} | ||||||
|  |  | ||||||
|  | 	encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) | ||||||
|  | 	var header header | ||||||
|  | 	header.Parse(encrypted) | ||||||
|  | 	header.SourceIP = 10 | ||||||
|  | 	header.Marshal(encrypted) | ||||||
|  |  | ||||||
|  | 	h.WRemote.writeTo(encrypted, netip.AddrPort{}) | ||||||
|  | 	h.R.handleNextPacket() | ||||||
|  | 	if len(h.Super.Messages) != 0 { | ||||||
|  | 		t.Fatal(h.Super.Messages) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that control packet with incrrect destination IP is ignored. | ||||||
|  | func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) { | ||||||
|  | 	h := newConnReadeTestHarness() | ||||||
|  |  | ||||||
|  | 	pkt := synPacket{TraceID: 1234} | ||||||
|  |  | ||||||
|  | 	encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) | ||||||
|  | 	var header header | ||||||
|  | 	header.Parse(encrypted) | ||||||
|  | 	header.DestIP++ | ||||||
|  | 	header.Marshal(encrypted) | ||||||
|  |  | ||||||
|  | 	h.WRemote.writeTo(encrypted, netip.AddrPort{}) | ||||||
|  | 	h.R.handleNextPacket() | ||||||
|  | 	if len(h.Super.Messages) != 0 { | ||||||
|  | 		t.Fatal(h.Super.Messages) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that modified control packet is ignored. | ||||||
|  | func TestConnReader_handleControlPacket_modified(t *testing.T) { | ||||||
|  | 	h := newConnReadeTestHarness() | ||||||
|  |  | ||||||
|  | 	pkt := synPacket{TraceID: 1234} | ||||||
|  |  | ||||||
|  | 	encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) | ||||||
|  | 	encrypted[len(encrypted)-1]++ | ||||||
|  |  | ||||||
|  | 	h.WRemote.writeTo(encrypted, netip.AddrPort{}) | ||||||
|  | 	h.R.handleNextPacket() | ||||||
|  | 	if len(h.Super.Messages) != 0 { | ||||||
|  | 		t.Fatal(h.Super.Messages) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type emptyPacket struct{} | ||||||
|  |  | ||||||
|  | func (p emptyPacket) Marshal(buf []byte) []byte { | ||||||
|  | 	return buf[:0] | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that an empty control packet is ignored. | ||||||
|  | func TestConnReader_handleControlPacket_empty(t *testing.T) { | ||||||
|  | 	h := newConnReadeTestHarness() | ||||||
|  |  | ||||||
|  | 	pkt := emptyPacket{} | ||||||
|  |  | ||||||
|  | 	encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) | ||||||
|  | 	h.WRemote.writeTo(encrypted, netip.AddrPort{}) | ||||||
|  | 	h.R.handleNextPacket() | ||||||
|  | 	if len(h.Super.Messages) != 0 { | ||||||
|  | 		t.Fatal(h.Super.Messages) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that a duplicate control packet is ignored. | ||||||
|  | func TestConnReader_handleControlPacket_duplicate(t *testing.T) { | ||||||
|  | 	h := newConnReadeTestHarness() | ||||||
|  |  | ||||||
|  | 	pkt := synPacket{TraceID: 1234} | ||||||
|  |  | ||||||
|  | 	log.Printf("%d", h.WRemote.counters[1]) | ||||||
|  | 	h.WRemote.SendControlPacket(pkt, h.Remote) | ||||||
|  | 	log.Printf("%d", h.WRemote.counters[1]) | ||||||
|  |  | ||||||
|  | 	// Rewind the counter. | ||||||
|  | 	h.WRemote.counters[1] = h.WRemote.counters[1] - 1 | ||||||
|  | 	log.Printf("%d", h.WRemote.counters[1]) | ||||||
|  | 	h.WRemote.SendControlPacket(pkt, h.Remote) | ||||||
|  |  | ||||||
|  | 	h.R.handleNextPacket() | ||||||
|  | 	h.R.handleNextPacket() | ||||||
|  |  | ||||||
|  | 	if len(h.Super.Messages) != 1 { | ||||||
|  | 		t.Fatal(h.Super.Messages) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	msg := h.Super.Messages[0].(controlMsg[synPacket]) | ||||||
|  | 	if !reflect.DeepEqual(pkt, msg.Packet) { | ||||||
|  | 		t.Fatal(msg.Packet) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type invalidPacket struct { | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p invalidPacket) Marshal(b []byte) []byte { | ||||||
|  | 	out := b[:256] | ||||||
|  | 	clear(out) | ||||||
|  | 	return out | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that an invalid control packet is ignored (fails to parse). | ||||||
|  | func TestConnReader_handleControlPacket_cantParse(t *testing.T) { | ||||||
|  | 	h := newConnReadeTestHarness() | ||||||
|  |  | ||||||
|  | 	pkt := invalidPacket{} | ||||||
|  |  | ||||||
|  | 	encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) | ||||||
|  | 	h.WRemote.writeTo(encrypted, netip.AddrPort{}) | ||||||
|  | 	h.R.handleNextPacket() | ||||||
|  | 	if len(h.Super.Messages) != 0 { | ||||||
|  | 		t.Fatal(h.Super.Messages) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that we can receive a data packet. | ||||||
|  | func TestConnReader_handleDataPacket(t *testing.T) { | ||||||
|  | 	h := newConnReadeTestHarness() | ||||||
|  |  | ||||||
|  | 	pkt := make([]byte, 1024) | ||||||
|  | 	rand.Read(pkt) | ||||||
|  |  | ||||||
|  | 	h.WRemote.SendDataPacket(pkt, h.Remote) | ||||||
|  |  | ||||||
|  | 	h.R.handleNextPacket() | ||||||
|  |  | ||||||
|  | 	if len(h.IFace.Written) != 1 { | ||||||
|  | 		t.Fatal(h.IFace.Written) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !bytes.Equal(pkt, h.IFace.Written[0]) { | ||||||
|  | 		t.Fatal(h.IFace.Written) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that data packet is ignored if route isn't up. | ||||||
|  | func TestConnReader_handleDataPacket_routeDown(t *testing.T) { | ||||||
|  | 	h := newConnReadeTestHarness() | ||||||
|  |  | ||||||
|  | 	pkt := make([]byte, 1024) | ||||||
|  | 	rand.Read(pkt) | ||||||
|  |  | ||||||
|  | 	h.WRemote.SendDataPacket(pkt, h.Remote) | ||||||
|  | 	route := h.R.routes[2].Load() | ||||||
|  | 	route.Up = false | ||||||
|  |  | ||||||
|  | 	h.R.handleNextPacket() | ||||||
|  |  | ||||||
|  | 	if len(h.IFace.Written) != 0 { | ||||||
|  | 		t.Fatal(h.IFace.Written) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | */ | ||||||
|  | // Testing that a duplicate data packet is ignored. | ||||||
|  |  | ||||||
|  | // Testing that we send a relayed data packet. | ||||||
|  |  | ||||||
|  | // Testing that a relayed data packet is ignored if destination isn't up. | ||||||
							
								
								
									
										80
									
								
								peer/connwriter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										80
									
								
								peer/connwriter.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,80 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"log" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"sync" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type connWriter struct { | ||||||
|  | 	localIP byte | ||||||
|  | 	conn    udpWriter | ||||||
|  |  | ||||||
|  | 	// For sending control packets. | ||||||
|  | 	cBuf1 []byte | ||||||
|  | 	cBuf2 []byte | ||||||
|  |  | ||||||
|  | 	// For sending data packets. | ||||||
|  | 	dBuf1 []byte | ||||||
|  | 	dBuf2 []byte | ||||||
|  |  | ||||||
|  | 	// Lock around for sending on UDP Conn. | ||||||
|  | 	wLock sync.Mutex | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newConnWriter(conn udpWriter, localIP byte) *connWriter { | ||||||
|  | 	w := &connWriter{ | ||||||
|  | 		localIP: localIP, | ||||||
|  | 		conn:    conn, | ||||||
|  | 		cBuf1:   make([]byte, bufferSize), | ||||||
|  | 		cBuf2:   make([]byte, bufferSize), | ||||||
|  | 		dBuf1:   make([]byte, bufferSize), | ||||||
|  | 		dBuf2:   make([]byte, bufferSize), | ||||||
|  | 	} | ||||||
|  | 	return w | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Not safe for concurrent use. Should only be called by supervisor. | ||||||
|  | func (w *connWriter) SendControlPacket(pkt marshaller, route *peerRoute) { | ||||||
|  | 	enc := encryptControlPacket(w.localIP, route, pkt, w.cBuf1, w.cBuf2) | ||||||
|  | 	w.writeTo(enc, route.RemoteAddr) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Relay control packet. Route must not be nil. | ||||||
|  | func (w *connWriter) RelayControlPacket(pkt marshaller, route, relay *peerRoute) { | ||||||
|  | 	enc := encryptControlPacket(w.localIP, route, pkt, w.cBuf1, w.cBuf2) | ||||||
|  | 	enc = encryptDataPacket(w.localIP, route.IP, relay, enc, w.cBuf1) | ||||||
|  | 	w.writeTo(enc, relay.RemoteAddr) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Not safe for concurrent use. Should only be called by ifReader. | ||||||
|  | func (w *connWriter) SendDataPacket(pkt []byte, route *peerRoute) { | ||||||
|  | 	enc := encryptDataPacket(w.localIP, route.IP, route, pkt, w.dBuf1) | ||||||
|  | 	w.writeTo(enc, route.RemoteAddr) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Relay a data packet. Route must not be nil. | ||||||
|  | func (w *connWriter) RelayDataPacket(pkt []byte, route, relay *peerRoute) { | ||||||
|  | 	enc := encryptDataPacket(w.localIP, route.IP, route, pkt, w.dBuf1) | ||||||
|  | 	enc = encryptDataPacket(w.localIP, route.IP, relay, enc, w.dBuf2) | ||||||
|  | 	w.writeTo(enc, relay.RemoteAddr) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Safe for concurrent use. Should only be called by connReader. | ||||||
|  | // | ||||||
|  | // This function will send pkt to the peer directly. This is used when a peer | ||||||
|  | // is acting as a relay and is forwarding already encrypted data for another | ||||||
|  | // peer. | ||||||
|  | func (w *connWriter) SendEncryptedDataPacket(pkt []byte, route *peerRoute) { | ||||||
|  | 	w.writeTo(pkt, route.RemoteAddr) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *connWriter) writeTo(packet []byte, addr netip.AddrPort) { | ||||||
|  | 	w.wLock.Lock() | ||||||
|  | 	if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil { | ||||||
|  | 		log.Printf("[ConnWriter] Failed to write to UDP port: %v", err) | ||||||
|  | 	} | ||||||
|  | 	w.wLock.Unlock() | ||||||
|  | } | ||||||
							
								
								
									
										240
									
								
								peer/connwriter_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										240
									
								
								peer/connwriter_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,240 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type testUDPPacket struct { | ||||||
|  | 	Addr netip.AddrPort | ||||||
|  | 	Data []byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type testUDPAddrPortWriter struct { | ||||||
|  | 	written []testUDPPacket | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *testUDPAddrPortWriter) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { | ||||||
|  | 	w.written = append(w.written, testUDPPacket{ | ||||||
|  | 		Addr: addr, | ||||||
|  | 		Data: bytes.Clone(b), | ||||||
|  | 	}) | ||||||
|  | 	return len(b), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *testUDPAddrPortWriter) Written() []testUDPPacket { | ||||||
|  | 	out := w.written | ||||||
|  | 	w.written = []testUDPPacket{} | ||||||
|  | 	return out | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type testPacket string | ||||||
|  |  | ||||||
|  | func (p testPacket) Marshal(b []byte) []byte { | ||||||
|  | 	b = b[:len(p)] | ||||||
|  | 	copy(b, []byte(p)) | ||||||
|  | 	return b | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | func testConnWriter_getTestRoutes() (local, remote, relayLocal, relayRemote *peerRoute) { | ||||||
|  | 	localKeys := generateKeys() | ||||||
|  | 	remoteKeys := generateKeys() | ||||||
|  |  | ||||||
|  | 	local = newPeerRoute(2) | ||||||
|  | 	local.Up = true | ||||||
|  | 	local.Relay = false | ||||||
|  | 	local.PubSignKey = remoteKeys.PubSignKey | ||||||
|  | 	local.ControlCipher = newControlCipher(localKeys.PrivKey, remoteKeys.PubKey) | ||||||
|  | 	local.DataCipher = newDataCipher() | ||||||
|  | 	local.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 2}), 100) | ||||||
|  |  | ||||||
|  | 	remote = newPeerRoute(1) | ||||||
|  | 	remote.Up = true | ||||||
|  | 	remote.Relay = false | ||||||
|  | 	remote.PubSignKey = localKeys.PubSignKey | ||||||
|  | 	remote.ControlCipher = newControlCipher(remoteKeys.PrivKey, localKeys.PubKey) | ||||||
|  | 	remote.DataCipher = local.DataCipher | ||||||
|  | 	remote.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100) | ||||||
|  |  | ||||||
|  | 	rLocalKeys := generateKeys() | ||||||
|  | 	rRemoteKeys := generateKeys() | ||||||
|  |  | ||||||
|  | 	relayLocal = newPeerRoute(3) | ||||||
|  | 	relayLocal.Up = true | ||||||
|  | 	relayLocal.Relay = true | ||||||
|  | 	relayLocal.Direct = true | ||||||
|  | 	relayLocal.PubSignKey = rRemoteKeys.PubSignKey | ||||||
|  | 	relayLocal.ControlCipher = newControlCipher(rLocalKeys.PrivKey, rRemoteKeys.PubKey) | ||||||
|  | 	relayLocal.DataCipher = newDataCipher() | ||||||
|  | 	relayLocal.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 3}), 100) | ||||||
|  |  | ||||||
|  | 	relayRemote = newPeerRoute(1) | ||||||
|  | 	relayRemote.Up = true | ||||||
|  | 	relayRemote.Relay = false | ||||||
|  | 	relayRemote.Direct = true | ||||||
|  | 	relayRemote.PubSignKey = rLocalKeys.PubSignKey | ||||||
|  | 	relayRemote.ControlCipher = newControlCipher(rRemoteKeys.PrivKey, rLocalKeys.PubKey) | ||||||
|  | 	relayRemote.DataCipher = relayLocal.DataCipher | ||||||
|  | 	relayRemote.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100) | ||||||
|  |  | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | // Testing if we can send a control packet directly to the remote route. | ||||||
|  | func TestConnWriter_SendControlPacket_direct(t *testing.T) { | ||||||
|  | 	route, rRoute, _, _ := testConnWriter_getTestRoutes() | ||||||
|  | 	route.Direct = true | ||||||
|  |  | ||||||
|  | 	writer := &testUDPAddrPortWriter{} | ||||||
|  | 	w := newConnWriter(writer, rRoute.IP) | ||||||
|  | 	in := testPacket("hello world!") | ||||||
|  |  | ||||||
|  | 	w.SendControlPacket(in, route) | ||||||
|  | 	out := writer.Written() | ||||||
|  | 	if len(out) != 1 { | ||||||
|  | 		t.Fatal(out) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if out[0].Addr != route.RemoteAddr { | ||||||
|  | 		t.Fatal(out[0]) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	dec, ok := rRoute.ControlCipher.Decrypt(out[0].Data, make([]byte, 1024)) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal(ok) | ||||||
|  | 	} | ||||||
|  | 	if string(dec) != string(in) { | ||||||
|  | 		t.Fatal(dec) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing if we can relay a packet via an intermediary. | ||||||
|  | func TestConnWriter_RelayControlPacket_relay(t *testing.T) { | ||||||
|  | 	route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() | ||||||
|  |  | ||||||
|  | 	writer := &testUDPAddrPortWriter{} | ||||||
|  | 	w := newConnWriter(writer, rRoute.IP) | ||||||
|  | 	in := testPacket("hello world!") | ||||||
|  |  | ||||||
|  | 	w.RelayControlPacket(in, route, relay) | ||||||
|  |  | ||||||
|  | 	out := writer.Written() | ||||||
|  | 	if len(out) != 1 { | ||||||
|  | 		t.Fatal(out) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if out[0].Addr != relay.RemoteAddr { | ||||||
|  | 		t.Fatal(out[0]) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal(ok) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	dec2, ok := rRoute.ControlCipher.Decrypt(dec, make([]byte, 1024)) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal(ok) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if string(dec2) != string(in) { | ||||||
|  | 		t.Fatal(dec2) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that we can send a data packet directly to a remote route. | ||||||
|  | func TestConnWriter_SendDataPacket_direct(t *testing.T) { | ||||||
|  | 	route, rRoute, _, _ := testConnWriter_getTestRoutes() | ||||||
|  | 	route.Direct = true | ||||||
|  |  | ||||||
|  | 	writer := &testUDPAddrPortWriter{} | ||||||
|  | 	w := newConnWriter(writer, rRoute.IP) | ||||||
|  |  | ||||||
|  | 	in := []byte("hello world!") | ||||||
|  | 	w.SendDataPacket(in, route) | ||||||
|  |  | ||||||
|  | 	out := writer.Written() | ||||||
|  | 	if len(out) != 1 { | ||||||
|  | 		t.Fatal(out) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if out[0].Addr != route.RemoteAddr { | ||||||
|  | 		t.Fatal(out[0]) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	dec, ok := rRoute.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal(ok) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !bytes.Equal(dec, in) { | ||||||
|  | 		t.Fatal(dec) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that we can relay a data packet via a relay. | ||||||
|  | func TestConnWriter_RelayDataPacket_relay(t *testing.T) { | ||||||
|  | 	route, rRoute, relay, rRelay := testConnWriter_getTestRoutes() | ||||||
|  |  | ||||||
|  | 	writer := &testUDPAddrPortWriter{} | ||||||
|  | 	w := newConnWriter(writer, rRoute.IP) | ||||||
|  | 	in := []byte("Hello world!") | ||||||
|  |  | ||||||
|  | 	w.RelayDataPacket(in, route, relay) | ||||||
|  |  | ||||||
|  | 	out := writer.Written() | ||||||
|  | 	if len(out) != 1 { | ||||||
|  | 		t.Fatal(out) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if out[0].Addr != relay.RemoteAddr { | ||||||
|  | 		t.Fatal(out[0]) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024)) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal(ok) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	dec2, ok := rRoute.DataCipher.Decrypt(dec, make([]byte, 1024)) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal(ok) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !bytes.Equal(dec2, in) { | ||||||
|  | 		t.Fatal(dec2) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that we can send an already encrypted packet. | ||||||
|  | func TestConnWriter_SendEncryptedDataPacket(t *testing.T) { | ||||||
|  | 	route, rRoute, _, _ := testConnWriter_getTestRoutes() | ||||||
|  |  | ||||||
|  | 	writer := &testUDPAddrPortWriter{} | ||||||
|  | 	w := newConnWriter(writer, rRoute.IP) | ||||||
|  | 	in := []byte("Hello world!") | ||||||
|  |  | ||||||
|  | 	w.SendEncryptedDataPacket(in, route) | ||||||
|  |  | ||||||
|  | 	out := writer.Written() | ||||||
|  | 	if len(out) != 1 { | ||||||
|  | 		t.Fatal(out) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if out[0].Addr != route.RemoteAddr { | ||||||
|  | 		t.Fatal(out[0]) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !bytes.Equal(out[0].Data, in) { | ||||||
|  | 		t.Fatal(out[0]) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										58
									
								
								peer/controlmessage.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								peer/controlmessage.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,58 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 	"vppn/m" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type controlMsg[T any] struct { | ||||||
|  | 	SrcIP   byte | ||||||
|  | 	SrcAddr netip.AddrPort | ||||||
|  | 	// TODO: RecvdAt int64 // Unixmilli. | ||||||
|  | 	Packet T | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) { | ||||||
|  | 	switch buf[0] { | ||||||
|  |  | ||||||
|  | 	case packetTypeSyn: | ||||||
|  | 		packet, err := parseSynPacket(buf) | ||||||
|  | 		return controlMsg[synPacket]{ | ||||||
|  | 			SrcIP:   srcIP, | ||||||
|  | 			SrcAddr: srcAddr, | ||||||
|  | 			Packet:  packet, | ||||||
|  | 		}, err | ||||||
|  |  | ||||||
|  | 	case packetTypeAck: | ||||||
|  | 		packet, err := parseAckPacket(buf) | ||||||
|  | 		return controlMsg[ackPacket]{ | ||||||
|  | 			SrcIP:   srcIP, | ||||||
|  | 			SrcAddr: srcAddr, | ||||||
|  | 			Packet:  packet, | ||||||
|  | 		}, err | ||||||
|  |  | ||||||
|  | 	case packetTypeProbe: | ||||||
|  | 		packet, err := parseProbePacket(buf) | ||||||
|  | 		return controlMsg[probePacket]{ | ||||||
|  | 			SrcIP:   srcIP, | ||||||
|  | 			SrcAddr: srcAddr, | ||||||
|  | 			Packet:  packet, | ||||||
|  | 		}, err | ||||||
|  |  | ||||||
|  | 	default: | ||||||
|  | 		return nil, errUnknownPacketType | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type peerUpdateMsg struct { | ||||||
|  | 	PeerIP byte | ||||||
|  | 	Peer   *m.Peer | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type pingTimerMsg struct{} | ||||||
							
								
								
									
										113
									
								
								peer/crypto.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										113
									
								
								peer/crypto.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,113 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"log" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"sync/atomic" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/crypto/nacl/box" | ||||||
|  | 	"golang.org/x/crypto/nacl/sign" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type cryptoKeys struct { | ||||||
|  | 	PubKey      []byte | ||||||
|  | 	PrivKey     []byte | ||||||
|  | 	PubSignKey  []byte | ||||||
|  | 	PrivSignKey []byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func generateKeys() cryptoKeys { | ||||||
|  | 	pubKey, privKey, err := box.GenerateKey(rand.Reader) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatalf("Failed to generate encryption keys: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	pubSignKey, privSignKey, err := sign.GenerateKey(rand.Reader) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatalf("Failed to generate signing keys: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return cryptoKeys{pubKey[:], privKey[:], pubSignKey[:], privSignKey[:]} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | // Route must have a ControlCipher. | ||||||
|  | func encryptControlPacket( | ||||||
|  | 	localIP byte, | ||||||
|  | 	route *peerRoute, | ||||||
|  | 	pkt marshaller, | ||||||
|  | 	tmp []byte, | ||||||
|  | 	out []byte, | ||||||
|  | ) []byte { | ||||||
|  | 	h := header{ | ||||||
|  | 		StreamID: controlStreamID, | ||||||
|  | 		Counter:  atomic.AddUint64(route.Counter, 1), | ||||||
|  | 		SourceIP: localIP, | ||||||
|  | 		DestIP:   route.IP, | ||||||
|  | 	} | ||||||
|  | 	tmp = pkt.Marshal(tmp) | ||||||
|  | 	return route.ControlCipher.Encrypt(h, tmp, out) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Returns a controlMsg[PacketType]. Route must have ControlCipher. | ||||||
|  | func decryptControlPacket( | ||||||
|  | 	route *peerRoute, | ||||||
|  | 	fromAddr netip.AddrPort, | ||||||
|  | 	h header, | ||||||
|  | 	encrypted []byte, | ||||||
|  | 	tmp []byte, | ||||||
|  | ) (any, error) { | ||||||
|  | 	out, ok := route.ControlCipher.Decrypt(encrypted, tmp) | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil, errDecryptionFailed | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if route.DupCheck.IsDup(h.Counter) { | ||||||
|  | 		return nil, errDuplicateSeqNum | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	msg, err := parseControlMsg(h.SourceIP, fromAddr, out) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return msg, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | func encryptDataPacket( | ||||||
|  | 	localIP byte, | ||||||
|  | 	destIP byte, | ||||||
|  | 	route *peerRoute, | ||||||
|  | 	data []byte, | ||||||
|  | 	out []byte, | ||||||
|  | ) []byte { | ||||||
|  | 	h := header{ | ||||||
|  | 		StreamID: dataStreamID, | ||||||
|  | 		Counter:  atomic.AddUint64(route.Counter, 1), | ||||||
|  | 		SourceIP: localIP, | ||||||
|  | 		DestIP:   destIP, | ||||||
|  | 	} | ||||||
|  | 	return route.DataCipher.Encrypt(h, data, out) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func decryptDataPacket( | ||||||
|  | 	route *peerRoute, | ||||||
|  | 	h header, | ||||||
|  | 	encrypted []byte, | ||||||
|  | 	out []byte, | ||||||
|  | ) ([]byte, error) { | ||||||
|  | 	dec, ok := route.DataCipher.Decrypt(encrypted, out) | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil, errDecryptionFailed | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if route.DupCheck.IsDup(h.Counter) { | ||||||
|  | 		return nil, errDuplicateSeqNum | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return dec, nil | ||||||
|  | } | ||||||
							
								
								
									
										213
									
								
								peer/crypto_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										213
									
								
								peer/crypto_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,213 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"errors" | ||||||
|  | 	"net/netip" | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func newRoutePairForTesting() (*peerRoute, *peerRoute) { | ||||||
|  | 	keys1 := generateKeys() | ||||||
|  | 	keys2 := generateKeys() | ||||||
|  |  | ||||||
|  | 	r1 := newPeerRoute(1) | ||||||
|  | 	r1.PubSignKey = keys1.PubSignKey | ||||||
|  | 	r1.ControlCipher = newControlCipher(keys1.PrivKey, keys2.PubKey) | ||||||
|  | 	r1.DataCipher = newDataCipher() | ||||||
|  |  | ||||||
|  | 	r2 := newPeerRoute(2) | ||||||
|  | 	r2.PubSignKey = keys2.PubSignKey | ||||||
|  | 	r2.ControlCipher = newControlCipher(keys2.PrivKey, keys1.PubKey) | ||||||
|  | 	r2.DataCipher = r1.DataCipher | ||||||
|  |  | ||||||
|  | 	return r1, r2 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestDecryptControlPacket(t *testing.T) { | ||||||
|  | 	var ( | ||||||
|  | 		r1, r2 = newRoutePairForTesting() | ||||||
|  | 		tmp    = make([]byte, bufferSize) | ||||||
|  | 		out    = make([]byte, bufferSize) | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	in := synPacket{ | ||||||
|  | 		TraceID:   newTraceID(), | ||||||
|  | 		SharedKey: r1.DataCipher.Key(), | ||||||
|  | 		Direct:    true, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	enc := encryptControlPacket(r1.IP, r2, in, tmp, out) | ||||||
|  | 	h, ok := parseHeader(enc) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal(h, ok) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	iMsg, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	msg, ok := iMsg.(controlMsg[synPacket]) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal(ok) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !reflect.DeepEqual(msg.Packet, in) { | ||||||
|  | 		t.Fatal(msg) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestDecryptControlPacket_decryptionFailed(t *testing.T) { | ||||||
|  | 	var ( | ||||||
|  | 		r1, r2 = newRoutePairForTesting() | ||||||
|  | 		tmp    = make([]byte, bufferSize) | ||||||
|  | 		out    = make([]byte, bufferSize) | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	in := synPacket{ | ||||||
|  | 		TraceID:   newTraceID(), | ||||||
|  | 		SharedKey: r1.DataCipher.Key(), | ||||||
|  | 		Direct:    true, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	enc := encryptControlPacket(r1.IP, r2, in, tmp, out) | ||||||
|  | 	h, ok := parseHeader(enc) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal(h, ok) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for i := range enc { | ||||||
|  | 		x := bytes.Clone(enc) | ||||||
|  | 		x[i]++ | ||||||
|  | 		_, err := decryptControlPacket(r2, netip.AddrPort{}, h, x, tmp) | ||||||
|  | 		if !errors.Is(err, errDecryptionFailed) { | ||||||
|  | 			t.Fatal(i, err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestDecryptControlPacket_duplicate(t *testing.T) { | ||||||
|  | 	var ( | ||||||
|  | 		r1, r2 = newRoutePairForTesting() | ||||||
|  | 		tmp    = make([]byte, bufferSize) | ||||||
|  | 		out    = make([]byte, bufferSize) | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	in := synPacket{ | ||||||
|  | 		TraceID:   newTraceID(), | ||||||
|  | 		SharedKey: r1.DataCipher.Key(), | ||||||
|  | 		Direct:    true, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	enc := encryptControlPacket(r1.IP, r2, in, tmp, out) | ||||||
|  | 	h, ok := parseHeader(enc) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal(h, ok) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) | ||||||
|  | 	if !errors.Is(err, errDuplicateSeqNum) { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestDecryptControlPacket_invalidPacket(t *testing.T) { | ||||||
|  | 	var ( | ||||||
|  | 		r1, r2 = newRoutePairForTesting() | ||||||
|  | 		tmp    = make([]byte, bufferSize) | ||||||
|  | 		out    = make([]byte, bufferSize) | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	in := testPacket("hello!") | ||||||
|  |  | ||||||
|  | 	enc := encryptControlPacket(r1.IP, r2, in, tmp, out) | ||||||
|  | 	h, ok := parseHeader(enc) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal(h, ok) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) | ||||||
|  | 	if !errors.Is(err, errUnknownPacketType) { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestDecryptDataPacket(t *testing.T) { | ||||||
|  | 	var ( | ||||||
|  | 		r1, r2 = newRoutePairForTesting() | ||||||
|  | 		out    = make([]byte, bufferSize) | ||||||
|  | 		data   = make([]byte, 1024) | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	rand.Read(data) | ||||||
|  |  | ||||||
|  | 	enc := encryptDataPacket(r1.IP, r2.IP, r2, data, out) | ||||||
|  | 	h, ok := parseHeader(enc) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal(h, ok) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	out, err := decryptDataPacket(r1, h, bytes.Clone(enc), out) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !bytes.Equal(data, out) { | ||||||
|  | 		t.Fatal(data, out) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestDecryptDataPacket_incorrectCipher(t *testing.T) { | ||||||
|  | 	var ( | ||||||
|  | 		r1, r2 = newRoutePairForTesting() | ||||||
|  | 		out    = make([]byte, bufferSize) | ||||||
|  | 		data   = make([]byte, 1024) | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	rand.Read(data) | ||||||
|  |  | ||||||
|  | 	enc := encryptDataPacket(r1.IP, r2.IP, r2, data, bytes.Clone(out)) | ||||||
|  | 	h, ok := parseHeader(enc) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal(h, ok) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	r1.DataCipher = newDataCipher() | ||||||
|  | 	_, err := decryptDataPacket(r1, h, enc, bytes.Clone(out)) | ||||||
|  | 	if !errors.Is(err, errDecryptionFailed) { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestDecryptDataPacket_duplicate(t *testing.T) { | ||||||
|  | 	var ( | ||||||
|  | 		r1, r2 = newRoutePairForTesting() | ||||||
|  | 		out    = make([]byte, bufferSize) | ||||||
|  | 		data   = make([]byte, 1024) | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	rand.Read(data) | ||||||
|  |  | ||||||
|  | 	enc := encryptDataPacket(r1.IP, r2.IP, r2, data, bytes.Clone(out)) | ||||||
|  | 	h, ok := parseHeader(enc) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal(h, ok) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, err := decryptDataPacket(r1, h, enc, bytes.Clone(out)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, err = decryptDataPacket(r1, h, enc, bytes.Clone(out)) | ||||||
|  | 	if !errors.Is(err, errDuplicateSeqNum) { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										76
									
								
								peer/dupcheck.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										76
									
								
								peer/dupcheck.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,76 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | type dupCheck struct { | ||||||
|  | 	bitSet | ||||||
|  | 	head        int | ||||||
|  | 	tail        int | ||||||
|  | 	headCounter uint64 | ||||||
|  | 	tailCounter uint64 // Also next expected counter value. | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newDupCheck(headCounter uint64) *dupCheck { | ||||||
|  | 	return &dupCheck{ | ||||||
|  | 		headCounter: headCounter, | ||||||
|  | 		tailCounter: headCounter + 1, | ||||||
|  | 		tail:        1, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (dc *dupCheck) IsDup(counter uint64) bool { | ||||||
|  |  | ||||||
|  | 	// Before head => it's late, say it's a dup. | ||||||
|  | 	if counter < dc.headCounter { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// It's within the counter bounds. | ||||||
|  | 	if counter < dc.tailCounter { | ||||||
|  | 		index := (int(counter-dc.headCounter) + dc.head) % bitSetSize | ||||||
|  | 		if dc.Get(index) { | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		dc.Set(index) | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// It's more than 1 beyond the tail. | ||||||
|  | 	delta := counter - dc.tailCounter | ||||||
|  |  | ||||||
|  | 	// Full clear. | ||||||
|  | 	if delta >= bitSetSize-1 { | ||||||
|  | 		dc.ClearAll() | ||||||
|  | 		dc.Set(0) | ||||||
|  |  | ||||||
|  | 		dc.tail = 1 | ||||||
|  | 		dc.head = 2 | ||||||
|  | 		dc.tailCounter = counter + 1 | ||||||
|  | 		dc.headCounter = dc.tailCounter - bitSetSize + 1 | ||||||
|  |  | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Clear if necessary. | ||||||
|  | 	for i := 0; i < int(delta); i++ { | ||||||
|  | 		dc.put(false) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	dc.put(true) | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (dc *dupCheck) put(set bool) { | ||||||
|  | 	if set { | ||||||
|  | 		dc.Set(dc.tail) | ||||||
|  | 	} else { | ||||||
|  | 		dc.Clear(dc.tail) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	dc.tail = (dc.tail + 1) % bitSetSize | ||||||
|  | 	dc.tailCounter++ | ||||||
|  |  | ||||||
|  | 	if dc.head == dc.tail { | ||||||
|  | 		dc.head = (dc.head + 1) % bitSetSize | ||||||
|  | 		dc.headCounter++ | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										57
									
								
								peer/dupcheck_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								peer/dupcheck_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,57 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestDupCheck(t *testing.T) { | ||||||
|  | 	dc := newDupCheck(0) | ||||||
|  |  | ||||||
|  | 	for i := range bitSetSize { | ||||||
|  | 		if dc.IsDup(uint64(i)) { | ||||||
|  | 			t.Fatal("!") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	type TestCase struct { | ||||||
|  | 		Counter uint64 | ||||||
|  | 		Dup     bool | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	testCases := []TestCase{ | ||||||
|  | 		{511, true}, | ||||||
|  | 		{0, true}, | ||||||
|  | 		{1, true}, | ||||||
|  | 		{2, true}, | ||||||
|  | 		{3, true}, | ||||||
|  | 		{63, true}, | ||||||
|  | 		{256, true}, | ||||||
|  | 		{510, true}, | ||||||
|  | 		{511, true}, | ||||||
|  | 		{512, false}, | ||||||
|  | 		{0, true}, | ||||||
|  | 		{512, true}, | ||||||
|  | 		{513, false}, | ||||||
|  | 		{517, false}, | ||||||
|  | 		{512, true}, | ||||||
|  | 		{513, true}, | ||||||
|  | 		{514, false}, | ||||||
|  | 		{515, false}, | ||||||
|  | 		{516, false}, | ||||||
|  | 		{517, true}, | ||||||
|  | 		{2512, false}, | ||||||
|  | 		{2512, true}, | ||||||
|  | 		{2001, true}, | ||||||
|  | 		{2002, false}, | ||||||
|  | 		{2002, true}, | ||||||
|  | 		{4000, false}, | ||||||
|  | 		{4000 - 511, true},  // Too old. | ||||||
|  | 		{4000 - 510, false}, // Just in the window. | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for i, tc := range testCases { | ||||||
|  | 		if ok := dc.IsDup(tc.Counter); ok != tc.Dup { | ||||||
|  | 			t.Fatal(i, ok, tc) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										10
									
								
								peer/errors.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								peer/errors.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import "errors" | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	errDecryptionFailed  = errors.New("decryption failed") | ||||||
|  | 	errDuplicateSeqNum   = errors.New("duplicate sequence number") | ||||||
|  | 	errMalformedPacket   = errors.New("malformed packet") | ||||||
|  | 	errUnknownPacketType = errors.New("unknown packet type") | ||||||
|  | ) | ||||||
							
								
								
									
										19
									
								
								peer/globals.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								peer/globals.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net" | ||||||
|  | 	"net/netip" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	bufferSize            = 1536 | ||||||
|  | 	if_mtu                = 1200 | ||||||
|  | 	if_queue_len          = 2048 | ||||||
|  | 	controlCipherOverhead = 16 | ||||||
|  | 	dataCipherOverhead    = 16 | ||||||
|  | 	signOverhead          = 64 | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( | ||||||
|  | 	netip.AddrFrom4([4]byte{224, 0, 0, 157}), | ||||||
|  | 	4560)) | ||||||
							
								
								
									
										49
									
								
								peer/header.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								peer/header.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,49 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import "unsafe" | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	headerSize        = 12 | ||||||
|  | 	controlStreamID   = 2 | ||||||
|  | 	controlHeaderSize = 24 | ||||||
|  | 	dataStreamID      = 1 | ||||||
|  | 	dataHeaderSize    = 12 | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type header struct { | ||||||
|  | 	Version  byte | ||||||
|  | 	StreamID byte | ||||||
|  | 	SourceIP byte | ||||||
|  | 	DestIP   byte | ||||||
|  | 	Counter  uint64 // Init with time.Now().Unix << 30 to ensure monotonic. | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func parseHeader(b []byte) (h header, ok bool) { | ||||||
|  | 	if len(b) < headerSize { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	h.Version = b[0] | ||||||
|  | 	h.StreamID = b[1] | ||||||
|  | 	h.SourceIP = b[2] | ||||||
|  | 	h.DestIP = b[3] | ||||||
|  | 	h.Counter = *(*uint64)(unsafe.Pointer(&b[4])) | ||||||
|  | 	return h, true | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *header) Parse(b []byte) { | ||||||
|  | 	h.Version = b[0] | ||||||
|  | 	h.StreamID = b[1] | ||||||
|  | 	h.SourceIP = b[2] | ||||||
|  | 	h.DestIP = b[3] | ||||||
|  | 	h.Counter = *(*uint64)(unsafe.Pointer(&b[4])) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *header) Marshal(buf []byte) { | ||||||
|  | 	buf[0] = h.Version | ||||||
|  | 	buf[1] = h.StreamID | ||||||
|  | 	buf[2] = h.SourceIP | ||||||
|  | 	buf[3] = h.DestIP | ||||||
|  | 	*(*uint64)(unsafe.Pointer(&buf[4])) = h.Counter | ||||||
|  | } | ||||||
							
								
								
									
										21
									
								
								peer/header_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								peer/header_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import "testing" | ||||||
|  |  | ||||||
|  | func TestHeaderMarshalParse(t *testing.T) { | ||||||
|  | 	nIn := header{ | ||||||
|  | 		StreamID: 23, | ||||||
|  | 		Counter:  3212, | ||||||
|  | 		SourceIP: 34, | ||||||
|  | 		DestIP:   200, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	buf := make([]byte, headerSize) | ||||||
|  | 	nIn.Marshal(buf) | ||||||
|  |  | ||||||
|  | 	nOut := header{} | ||||||
|  | 	nOut.Parse(buf) | ||||||
|  | 	if nIn != nOut { | ||||||
|  | 		t.Fatal(nIn, nOut) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										100
									
								
								peer/ifreader.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								peer/ifreader.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,100 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"io" | ||||||
|  | 	"log" | ||||||
|  | 	"sync/atomic" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type ifReader struct { | ||||||
|  | 	iface  io.Reader | ||||||
|  | 	routes [256]*atomic.Pointer[peerRoute] | ||||||
|  | 	relay  *atomic.Pointer[peerRoute] | ||||||
|  | 	sender dataPacketSender | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newIFReader( | ||||||
|  | 	iface io.Reader, | ||||||
|  | 	routes [256]*atomic.Pointer[peerRoute], | ||||||
|  | 	relay *atomic.Pointer[peerRoute], | ||||||
|  | 	sender dataPacketSender, | ||||||
|  | ) *ifReader { | ||||||
|  | 	return &ifReader{ | ||||||
|  | 		iface:  iface, | ||||||
|  | 		routes: routes, | ||||||
|  | 		relay:  relay, | ||||||
|  | 		sender: sender, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *ifReader) Run() { | ||||||
|  | 	var ( | ||||||
|  | 		packet   = make([]byte, bufferSize) | ||||||
|  | 		remoteIP byte | ||||||
|  | 		ok       bool | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	for { | ||||||
|  | 		packet = r.readNextPacket(packet) | ||||||
|  | 		if remoteIP, ok = r.parsePacket(packet); ok { | ||||||
|  | 			r.sendPacket(packet, remoteIP) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *ifReader) sendPacket(pkt []byte, remoteIP byte) { | ||||||
|  | 	route := r.routes[remoteIP].Load() | ||||||
|  | 	if !route.Up { | ||||||
|  | 		log.Printf("Route not connected: %d", remoteIP) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Direct path => early return. | ||||||
|  | 	if route.Direct { | ||||||
|  | 		r.sender.SendDataPacket(pkt, route) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if relay := r.relay.Load(); relay != nil && relay.Up { | ||||||
|  | 		r.sender.RelayDataPacket(pkt, route, relay) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Get next packet, returning packet, and destination ip. | ||||||
|  | func (r *ifReader) readNextPacket(buf []byte) []byte { | ||||||
|  | 	n, err := r.iface.Read(buf[:cap(buf)]) | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Fatalf("Failed to read from interface: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return buf[:n] | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *ifReader) parsePacket(buf []byte) (byte, bool) { | ||||||
|  | 	n := len(buf) | ||||||
|  | 	if n == 0 { | ||||||
|  | 		return 0, false | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	version := buf[0] >> 4 | ||||||
|  |  | ||||||
|  | 	switch version { | ||||||
|  | 	case 4: | ||||||
|  | 		if n < 20 { | ||||||
|  | 			log.Printf("Short IPv4 packet: %d", len(buf)) | ||||||
|  | 			return 0, false | ||||||
|  | 		} | ||||||
|  | 		return buf[19], true | ||||||
|  |  | ||||||
|  | 	case 6: | ||||||
|  | 		if len(buf) < 40 { | ||||||
|  | 			log.Printf("Short IPv6 packet: %d", len(buf)) | ||||||
|  | 			return 0, false | ||||||
|  | 		} | ||||||
|  | 		return buf[39], true | ||||||
|  |  | ||||||
|  | 	default: | ||||||
|  | 		log.Printf("Invalid IP packet version: %v", version) | ||||||
|  | 		return 0, false | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										232
									
								
								peer/ifreader_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										232
									
								
								peer/ifreader_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,232 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"net" | ||||||
|  | 	"reflect" | ||||||
|  | 	"sync/atomic" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // Test that we parse IPv4 packets correctly. | ||||||
|  | func TestIFReader_parsePacket_ipv4(t *testing.T) { | ||||||
|  | 	r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) | ||||||
|  |  | ||||||
|  | 	pkt := make([]byte, 1234) | ||||||
|  | 	pkt[0] = 4 << 4 | ||||||
|  | 	pkt[19] = 128 | ||||||
|  |  | ||||||
|  | 	if ip, ok := r.parsePacket(pkt); !ok || ip != 128 { | ||||||
|  | 		t.Fatal(ip, ok) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Test that we parse IPv6 packets correctly. | ||||||
|  | func TestIFReader_parsePacket_ipv6(t *testing.T) { | ||||||
|  | 	r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) | ||||||
|  |  | ||||||
|  | 	pkt := make([]byte, 1234) | ||||||
|  | 	pkt[0] = 6 << 4 | ||||||
|  | 	pkt[39] = 42 | ||||||
|  |  | ||||||
|  | 	if ip, ok := r.parsePacket(pkt); !ok || ip != 42 { | ||||||
|  | 		t.Fatal(ip, ok) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Test that empty packets work as expected. | ||||||
|  | func TestIFReader_parsePacket_emptyPacket(t *testing.T) { | ||||||
|  | 	r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) | ||||||
|  |  | ||||||
|  | 	pkt := make([]byte, 0) | ||||||
|  | 	if ip, ok := r.parsePacket(pkt); ok { | ||||||
|  | 		t.Fatal(ip, ok) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Test that invalid IP versions fail. | ||||||
|  | func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) { | ||||||
|  | 	r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) | ||||||
|  |  | ||||||
|  | 	for i := byte(1); i < 16; i++ { | ||||||
|  | 		if i == 4 || i == 6 { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		pkt := make([]byte, 1234) | ||||||
|  | 		pkt[0] = i << 4 | ||||||
|  |  | ||||||
|  | 		if ip, ok := r.parsePacket(pkt); ok { | ||||||
|  | 			t.Fatal(i, ip, ok) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Test that short IPv4 packets fail. | ||||||
|  | func TestIFReader_parsePacket_shortIPv4(t *testing.T) { | ||||||
|  | 	r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) | ||||||
|  |  | ||||||
|  | 	pkt := make([]byte, 19) | ||||||
|  | 	pkt[0] = 4 << 4 | ||||||
|  |  | ||||||
|  | 	if ip, ok := r.parsePacket(pkt); ok { | ||||||
|  | 		t.Fatal(ip, ok) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Test that short IPv6 packets fail. | ||||||
|  | func TestIFReader_parsePacket_shortIPv6(t *testing.T) { | ||||||
|  | 	r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil) | ||||||
|  |  | ||||||
|  | 	pkt := make([]byte, 39) | ||||||
|  | 	pkt[0] = 6 << 4 | ||||||
|  |  | ||||||
|  | 	if ip, ok := r.parsePacket(pkt); ok { | ||||||
|  | 		t.Fatal(ip, ok) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Test that we can read a packet. | ||||||
|  | func TestIFReader_readNextpacket(t *testing.T) { | ||||||
|  | 	in, out := net.Pipe() | ||||||
|  | 	r := newIFReader(out, [256]*atomic.Pointer[peerRoute]{}, nil, nil) | ||||||
|  | 	defer in.Close() | ||||||
|  | 	defer out.Close() | ||||||
|  |  | ||||||
|  | 	go in.Write([]byte("hello world!")) | ||||||
|  |  | ||||||
|  | 	pkt := r.readNextPacket(make([]byte, bufferSize)) | ||||||
|  | 	if !bytes.Equal(pkt, []byte("hello world!")) { | ||||||
|  | 		t.Fatalf("%s", pkt) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type sentPacket struct { | ||||||
|  | 	Relayed bool | ||||||
|  | 	Packet  []byte | ||||||
|  | 	Route   peerRoute | ||||||
|  | 	Relay   peerRoute | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type sendPacketTestHarness struct { | ||||||
|  | 	Packets []sentPacket | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *sendPacketTestHarness) SendDataPacket(pkt []byte, route *peerRoute) { | ||||||
|  | 	h.Packets = append(h.Packets, sentPacket{ | ||||||
|  | 		Packet: bytes.Clone(pkt), | ||||||
|  | 		Route:  *route, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *sendPacketTestHarness) RelayDataPacket(pkt []byte, route, relay *peerRoute) { | ||||||
|  | 	h.Packets = append(h.Packets, sentPacket{ | ||||||
|  | 		Relayed: true, | ||||||
|  | 		Packet:  bytes.Clone(pkt), | ||||||
|  | 		Route:   *route, | ||||||
|  | 		Relay:   *relay, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newIFReaderForSendPacketTesting() (*ifReader, *sendPacketTestHarness) { | ||||||
|  | 	h := &sendPacketTestHarness{} | ||||||
|  |  | ||||||
|  | 	routes := [256]*atomic.Pointer[peerRoute]{} | ||||||
|  | 	for i := range routes { | ||||||
|  | 		routes[i] = &atomic.Pointer[peerRoute]{} | ||||||
|  | 		routes[i].Store(&peerRoute{}) | ||||||
|  | 	} | ||||||
|  | 	relay := &atomic.Pointer[peerRoute]{} | ||||||
|  | 	r := newIFReader(nil, routes, relay, h) | ||||||
|  | 	return r, h | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that we can send a packet directly. | ||||||
|  | func TestIFReader_sendPacket_direct(t *testing.T) { | ||||||
|  | 	r, h := newIFReaderForSendPacketTesting() | ||||||
|  |  | ||||||
|  | 	route := r.routes[2].Load() | ||||||
|  | 	route.Up = true | ||||||
|  | 	route.Direct = true | ||||||
|  |  | ||||||
|  | 	in := []byte("hello world") | ||||||
|  |  | ||||||
|  | 	r.sendPacket(in, 2) | ||||||
|  | 	if len(h.Packets) != 1 { | ||||||
|  | 		t.Fatal(h.Packets) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	expected := sentPacket{ | ||||||
|  | 		Relayed: false, | ||||||
|  | 		Packet:  in, | ||||||
|  | 		Route:   *route, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !reflect.DeepEqual(h.Packets[0], expected) { | ||||||
|  | 		t.Fatal(h.Packets[0]) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that we don't send a packet if route isn't up. | ||||||
|  | func TestIFReader_sendPacket_directNotUp(t *testing.T) { | ||||||
|  | 	r, h := newIFReaderForSendPacketTesting() | ||||||
|  |  | ||||||
|  | 	route := r.routes[2].Load() | ||||||
|  | 	route.Direct = true | ||||||
|  |  | ||||||
|  | 	in := []byte("hello world") | ||||||
|  |  | ||||||
|  | 	r.sendPacket(in, 2) | ||||||
|  | 	if len(h.Packets) != 0 { | ||||||
|  | 		t.Fatal(h.Packets) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that we can send a packet via a relay. | ||||||
|  | func TestIFReader_sendPacket_relayed(t *testing.T) { | ||||||
|  | 	r, h := newIFReaderForSendPacketTesting() | ||||||
|  |  | ||||||
|  | 	route := r.routes[2].Load() | ||||||
|  | 	route.Up = true | ||||||
|  | 	route.Direct = false | ||||||
|  |  | ||||||
|  | 	relay := r.routes[3].Load() | ||||||
|  | 	r.relay.Store(relay) | ||||||
|  | 	relay.Up = true | ||||||
|  | 	relay.Direct = true | ||||||
|  |  | ||||||
|  | 	in := []byte("hello world") | ||||||
|  |  | ||||||
|  | 	r.sendPacket(in, 2) | ||||||
|  | 	if len(h.Packets) != 1 { | ||||||
|  | 		t.Fatal(h.Packets) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	expected := sentPacket{ | ||||||
|  | 		Relayed: true, | ||||||
|  | 		Packet:  in, | ||||||
|  | 		Route:   *route, | ||||||
|  | 		Relay:   *relay, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !reflect.DeepEqual(h.Packets[0], expected) { | ||||||
|  | 		t.Fatal(h.Packets[0]) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that we don't try to send on a nil relay IP. | ||||||
|  | func TestIFReader_sendPacket_nilRealy(t *testing.T) { | ||||||
|  | 	r, h := newIFReaderForSendPacketTesting() | ||||||
|  |  | ||||||
|  | 	route := r.routes[2].Load() | ||||||
|  | 	route.Up = true | ||||||
|  | 	route.Direct = false | ||||||
|  |  | ||||||
|  | 	in := []byte("hello world") | ||||||
|  |  | ||||||
|  | 	r.sendPacket(in, 2) | ||||||
|  | 	if len(h.Packets) != 0 { | ||||||
|  | 		t.Fatal(h.Packets) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										5
									
								
								peer/ifwriter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								peer/ifwriter.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import "io" | ||||||
|  |  | ||||||
|  | type ifWriter io.Writer | ||||||
							
								
								
									
										28
									
								
								peer/interfaces.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								peer/interfaces.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,28 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import "net/netip" | ||||||
|  |  | ||||||
|  | type udpReader interface { | ||||||
|  | 	ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type udpWriter interface { | ||||||
|  | 	WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type marshaller interface { | ||||||
|  | 	Marshal([]byte) []byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type dataPacketSender interface { | ||||||
|  | 	SendDataPacket(pkt []byte, route *peerRoute) | ||||||
|  | 	RelayDataPacket(pkt []byte, route, relay *peerRoute) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type encryptedPacketSender interface { | ||||||
|  | 	SendEncryptedDataPacket(pkt []byte, route *peerRoute) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type controlMsgHandler interface { | ||||||
|  | 	HandleControlMsg(pkt any) | ||||||
|  | } | ||||||
							
								
								
									
										62
									
								
								peer/mcwriter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								peer/mcwriter.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,62 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"log" | ||||||
|  | 	"net" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/crypto/nacl/sign" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type mcUDPWriter interface { | ||||||
|  | 	WriteToUDP([]byte, *net.UDPAddr) (int, error) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | func createLocalDiscoveryPacket(localIP byte, signingKey []byte) []byte { | ||||||
|  | 	h := header{ | ||||||
|  | 		SourceIP: localIP, | ||||||
|  | 		DestIP:   255, | ||||||
|  | 	} | ||||||
|  | 	buf := make([]byte, headerSize) | ||||||
|  | 	h.Marshal(buf) | ||||||
|  | 	out := make([]byte, headerSize+signOverhead) | ||||||
|  | 	return sign.Sign(out[:0], buf, (*[64]byte)(signingKey)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func headerFromLocalDiscoveryPacket(pkt []byte) (h header, ok bool) { | ||||||
|  | 	if len(pkt) != headerSize+signOverhead { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	h.Parse(pkt[signOverhead:]) | ||||||
|  | 	ok = true | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func verifyLocalDiscoveryPacket(pkt, buf []byte, pubSignKey []byte) bool { | ||||||
|  | 	_, ok := sign.Open(buf[:0], pkt, (*[32]byte)(pubSignKey)) | ||||||
|  | 	return ok | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type mcWriter struct { | ||||||
|  | 	conn            mcUDPWriter | ||||||
|  | 	discoveryPacket []byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newMCWriter(conn mcUDPWriter, localIP byte, signingKey []byte) *mcWriter { | ||||||
|  | 	return &mcWriter{ | ||||||
|  | 		conn:            conn, | ||||||
|  | 		discoveryPacket: createLocalDiscoveryPacket(localIP, signingKey), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *mcWriter) SendLocalDiscovery() { | ||||||
|  | 	if _, err := w.conn.WriteToUDP(w.discoveryPacket, multicastAddr); err != nil { | ||||||
|  | 		log.Printf("[MCWriter] Failed to write multicast UDP packet: %v", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										102
									
								
								peer/mcwriter_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										102
									
								
								peer/mcwriter_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,102 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"net" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | // Testing that we can create and verify a local discovery packet. | ||||||
|  | func TestVerifyLocalDiscoveryPacket_valid(t *testing.T) { | ||||||
|  | 	keys := generateKeys() | ||||||
|  |  | ||||||
|  | 	created := createLocalDiscoveryPacket(55, keys.PrivSignKey) | ||||||
|  |  | ||||||
|  | 	header, ok := headerFromLocalDiscoveryPacket(created) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal(ok) | ||||||
|  | 	} | ||||||
|  | 	if header.SourceIP != 55 || header.DestIP != 255 { | ||||||
|  | 		t.Fatal(header) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !verifyLocalDiscoveryPacket(created, make([]byte, 1024), keys.PubSignKey) { | ||||||
|  | 		t.Fatal("Not valid") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that we don't try to parse short packets. | ||||||
|  | func TestVerifyLocalDiscoveryPacket_tooShort(t *testing.T) { | ||||||
|  | 	keys := generateKeys() | ||||||
|  |  | ||||||
|  | 	created := createLocalDiscoveryPacket(55, keys.PrivSignKey) | ||||||
|  |  | ||||||
|  | 	_, ok := headerFromLocalDiscoveryPacket(created[:len(created)-1]) | ||||||
|  | 	if ok { | ||||||
|  | 		t.Fatal(ok) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Testing that modifying a packet makes it invalid. | ||||||
|  | func TestVerifyLocalDiscoveryPacket_invalid(t *testing.T) { | ||||||
|  | 	keys := generateKeys() | ||||||
|  |  | ||||||
|  | 	created := createLocalDiscoveryPacket(55, keys.PrivSignKey) | ||||||
|  | 	buf := make([]byte, 1024) | ||||||
|  | 	for i := range created { | ||||||
|  | 		modified := bytes.Clone(created) | ||||||
|  | 		modified[i]++ | ||||||
|  | 		if verifyLocalDiscoveryPacket(modified, buf, keys.PubSignKey) { | ||||||
|  | 			t.Fatal("Verification should have failed.") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type testUDPWriter struct { | ||||||
|  | 	written [][]byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *testUDPWriter) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { | ||||||
|  | 	w.written = append(w.written, bytes.Clone(b)) | ||||||
|  | 	return len(b), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *testUDPWriter) Written() [][]byte { | ||||||
|  | 	out := w.written | ||||||
|  | 	w.written = [][]byte{} | ||||||
|  | 	return out | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | // Testing that the mcWriter sends local discovery packets as expected. | ||||||
|  | func TestMCWriter_SendLocalDiscovery(t *testing.T) { | ||||||
|  | 	keys := generateKeys() | ||||||
|  | 	writer := &testUDPWriter{} | ||||||
|  |  | ||||||
|  | 	mcw := newMCWriter(writer, 42, keys.PrivSignKey) | ||||||
|  | 	mcw.SendLocalDiscovery() | ||||||
|  |  | ||||||
|  | 	out := writer.Written() | ||||||
|  | 	if len(out) != 1 { | ||||||
|  | 		t.Fatal(out) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	pkt := out[0] | ||||||
|  |  | ||||||
|  | 	header, ok := headerFromLocalDiscoveryPacket(pkt) | ||||||
|  | 	if !ok { | ||||||
|  | 		t.Fatal(ok) | ||||||
|  | 	} | ||||||
|  | 	if header.SourceIP != 42 || header.DestIP != 255 { | ||||||
|  | 		t.Fatal(header) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !verifyLocalDiscoveryPacket(pkt, make([]byte, 1024), keys.PubSignKey) { | ||||||
|  | 		t.Fatal("Verification should succeed.") | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										190
									
								
								peer/packets-util.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										190
									
								
								peer/packets-util.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,190 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 	"sync/atomic" | ||||||
|  | 	"time" | ||||||
|  | 	"unsafe" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var traceIDCounter uint64 = uint64(time.Now().Unix()<<30) + 1 | ||||||
|  |  | ||||||
|  | func newTraceID() uint64 { | ||||||
|  | 	return atomic.AddUint64(&traceIDCounter, 1) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type binWriter struct { | ||||||
|  | 	b []byte | ||||||
|  | 	i int | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newBinWriter(buf []byte) *binWriter { | ||||||
|  | 	buf = buf[:cap(buf)] | ||||||
|  | 	return &binWriter{buf, 0} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *binWriter) Bool(b bool) *binWriter { | ||||||
|  | 	if b { | ||||||
|  | 		return w.Byte(1) | ||||||
|  | 	} | ||||||
|  | 	return w.Byte(0) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *binWriter) Byte(b byte) *binWriter { | ||||||
|  | 	w.b[w.i] = b | ||||||
|  | 	w.i++ | ||||||
|  | 	return w | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *binWriter) SharedKey(key [32]byte) *binWriter { | ||||||
|  | 	copy(w.b[w.i:w.i+32], key[:]) | ||||||
|  | 	w.i += 32 | ||||||
|  | 	return w | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *binWriter) Uint16(x uint16) *binWriter { | ||||||
|  | 	*(*uint16)(unsafe.Pointer(&w.b[w.i])) = x | ||||||
|  | 	w.i += 2 | ||||||
|  | 	return w | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *binWriter) Uint64(x uint64) *binWriter { | ||||||
|  | 	*(*uint64)(unsafe.Pointer(&w.b[w.i])) = x | ||||||
|  | 	w.i += 8 | ||||||
|  | 	return w | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *binWriter) Int64(x int64) *binWriter { | ||||||
|  | 	*(*int64)(unsafe.Pointer(&w.b[w.i])) = x | ||||||
|  | 	w.i += 8 | ||||||
|  | 	return w | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter { | ||||||
|  | 	w.Bool(addrPort.IsValid()) | ||||||
|  | 	addr := addrPort.Addr().As16() | ||||||
|  | 	copy(w.b[w.i:w.i+16], addr[:]) | ||||||
|  | 	w.i += 16 | ||||||
|  | 	return w.Uint16(addrPort.Port()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *binWriter) AddrPortArray(l [8]netip.AddrPort) *binWriter { | ||||||
|  | 	for _, addrPort := range l { | ||||||
|  | 		w.AddrPort(addrPort) | ||||||
|  | 	} | ||||||
|  | 	return w | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *binWriter) Build() []byte { | ||||||
|  | 	return w.b[:w.i] | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type binReader struct { | ||||||
|  | 	b   []byte | ||||||
|  | 	i   int | ||||||
|  | 	err error | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newBinReader(buf []byte) *binReader { | ||||||
|  | 	return &binReader{b: buf} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) hasBytes(n int) bool { | ||||||
|  | 	if r.err != nil || (len(r.b)-r.i) < n { | ||||||
|  | 		r.err = errMalformedPacket | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	return true | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) Bool(b *bool) *binReader { | ||||||
|  | 	var bb byte | ||||||
|  | 	r.Byte(&bb) | ||||||
|  | 	*b = bb != 0 | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) Byte(b *byte) *binReader { | ||||||
|  | 	if !r.hasBytes(1) { | ||||||
|  | 		return r | ||||||
|  | 	} | ||||||
|  | 	*b = r.b[r.i] | ||||||
|  | 	r.i++ | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) SharedKey(x *[32]byte) *binReader { | ||||||
|  | 	if !r.hasBytes(32) { | ||||||
|  | 		return r | ||||||
|  | 	} | ||||||
|  | 	*x = ([32]byte)(r.b[r.i : r.i+32]) | ||||||
|  | 	r.i += 32 | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) Uint16(x *uint16) *binReader { | ||||||
|  | 	if !r.hasBytes(2) { | ||||||
|  | 		return r | ||||||
|  | 	} | ||||||
|  | 	*x = *(*uint16)(unsafe.Pointer(&r.b[r.i])) | ||||||
|  | 	r.i += 2 | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) Uint64(x *uint64) *binReader { | ||||||
|  | 	if !r.hasBytes(8) { | ||||||
|  | 		return r | ||||||
|  | 	} | ||||||
|  | 	*x = *(*uint64)(unsafe.Pointer(&r.b[r.i])) | ||||||
|  | 	r.i += 8 | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) Int64(x *int64) *binReader { | ||||||
|  | 	if !r.hasBytes(8) { | ||||||
|  | 		return r | ||||||
|  | 	} | ||||||
|  | 	*x = *(*int64)(unsafe.Pointer(&r.b[r.i])) | ||||||
|  | 	r.i += 8 | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) AddrPort(x *netip.AddrPort) *binReader { | ||||||
|  | 	if !r.hasBytes(19) { | ||||||
|  | 		return r | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var ( | ||||||
|  | 		valid bool | ||||||
|  | 		port  uint16 | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	r.Bool(&valid) | ||||||
|  | 	addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])).Unmap() | ||||||
|  | 	r.i += 16 | ||||||
|  |  | ||||||
|  | 	r.Uint16(&port) | ||||||
|  |  | ||||||
|  | 	if valid { | ||||||
|  | 		*x = netip.AddrPortFrom(addr, port) | ||||||
|  | 	} else { | ||||||
|  | 		*x = netip.AddrPort{} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) AddrPortArray(x *[8]netip.AddrPort) *binReader { | ||||||
|  | 	for i := range x { | ||||||
|  | 		r.AddrPort(&x[i]) | ||||||
|  | 	} | ||||||
|  | 	return r | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *binReader) Error() error { | ||||||
|  | 	return r.err | ||||||
|  | } | ||||||
							
								
								
									
										56
									
								
								peer/packets-util_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								peer/packets-util_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,56 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestBinWriteRead(t *testing.T) { | ||||||
|  | 	buf := make([]byte, 1024) | ||||||
|  |  | ||||||
|  | 	type Item struct { | ||||||
|  | 		Type     byte | ||||||
|  | 		TraceID  uint64 | ||||||
|  | 		Addrs    [8]netip.AddrPort | ||||||
|  | 		DestAddr netip.AddrPort | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	in := Item{ | ||||||
|  | 		1, | ||||||
|  | 		2, | ||||||
|  | 		[8]netip.AddrPort{}, | ||||||
|  | 		netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	in.Addrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20) | ||||||
|  | 	in.Addrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 22) | ||||||
|  | 	in.Addrs[3] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 3}), 23) | ||||||
|  | 	in.Addrs[4] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 4}), 24) | ||||||
|  | 	in.Addrs[5] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 25) | ||||||
|  | 	in.Addrs[6] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 6}), 26) | ||||||
|  | 	in.Addrs[7] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{7, 8, 9, 7}), 27) | ||||||
|  |  | ||||||
|  | 	buf = newBinWriter(buf). | ||||||
|  | 		Byte(in.Type). | ||||||
|  | 		Uint64(in.TraceID). | ||||||
|  | 		AddrPort(in.DestAddr). | ||||||
|  | 		AddrPortArray(in.Addrs). | ||||||
|  | 		Build() | ||||||
|  |  | ||||||
|  | 	out := Item{} | ||||||
|  |  | ||||||
|  | 	err := newBinReader(buf). | ||||||
|  | 		Byte(&out.Type). | ||||||
|  | 		Uint64(&out.TraceID). | ||||||
|  | 		AddrPort(&out.DestAddr). | ||||||
|  | 		AddrPortArray(&out.Addrs). | ||||||
|  | 		Error() | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !reflect.DeepEqual(in, out) { | ||||||
|  | 		t.Fatal(in, out) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										123
									
								
								peer/packets.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										123
									
								
								peer/packets.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,123 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	packetTypeSyn = iota + 1 | ||||||
|  | 	packetTypeSynAck | ||||||
|  | 	packetTypeAck | ||||||
|  | 	packetTypeProbe | ||||||
|  | 	packetTypeAddrDiscovery | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type synPacket struct { | ||||||
|  | 	TraceID uint64 // TraceID to match response w/ request. | ||||||
|  | 	// TODO: SentAt int64 // Unixmilli. | ||||||
|  | 	SharedKey     [32]byte // Our shared key. | ||||||
|  | 	Direct        bool | ||||||
|  | 	PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p synPacket) Marshal(buf []byte) []byte { | ||||||
|  | 	return newBinWriter(buf). | ||||||
|  | 		Byte(packetTypeSyn). | ||||||
|  | 		Uint64(p.TraceID). | ||||||
|  | 		SharedKey(p.SharedKey). | ||||||
|  | 		Bool(p.Direct). | ||||||
|  | 		AddrPort(p.PossibleAddrs[0]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[1]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[2]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[3]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[4]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[5]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[6]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[7]). | ||||||
|  | 		Build() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func parseSynPacket(buf []byte) (p synPacket, err error) { | ||||||
|  | 	err = newBinReader(buf[1:]). | ||||||
|  | 		Uint64(&p.TraceID). | ||||||
|  | 		SharedKey(&p.SharedKey). | ||||||
|  | 		Bool(&p.Direct). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[0]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[1]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[2]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[3]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[4]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[5]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[6]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[7]). | ||||||
|  | 		Error() | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type ackPacket struct { | ||||||
|  | 	TraceID       uint64 | ||||||
|  | 	ToAddr        netip.AddrPort | ||||||
|  | 	PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender. | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p ackPacket) Marshal(buf []byte) []byte { | ||||||
|  | 	return newBinWriter(buf). | ||||||
|  | 		Byte(packetTypeAck). | ||||||
|  | 		Uint64(p.TraceID). | ||||||
|  | 		AddrPort(p.ToAddr). | ||||||
|  | 		AddrPort(p.PossibleAddrs[0]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[1]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[2]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[3]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[4]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[5]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[6]). | ||||||
|  | 		AddrPort(p.PossibleAddrs[7]). | ||||||
|  | 		Build() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func parseAckPacket(buf []byte) (p ackPacket, err error) { | ||||||
|  | 	err = newBinReader(buf[1:]). | ||||||
|  | 		Uint64(&p.TraceID). | ||||||
|  | 		AddrPort(&p.ToAddr). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[0]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[1]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[2]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[3]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[4]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[5]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[6]). | ||||||
|  | 		AddrPort(&p.PossibleAddrs[7]). | ||||||
|  | 		Error() | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | // A probeReqPacket is sent from a client to a server to determine if direct | ||||||
|  | // UDP communication can be used. | ||||||
|  | type probePacket struct { | ||||||
|  | 	TraceID uint64 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p probePacket) Marshal(buf []byte) []byte { | ||||||
|  | 	return newBinWriter(buf). | ||||||
|  | 		Byte(packetTypeProbe). | ||||||
|  | 		Uint64(p.TraceID). | ||||||
|  | 		Build() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func parseProbePacket(buf []byte) (p probePacket, err error) { | ||||||
|  | 	err = newBinReader(buf[1:]). | ||||||
|  | 		Uint64(&p.TraceID). | ||||||
|  | 		Error() | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type localDiscoveryPacket struct{} | ||||||
							
								
								
									
										1
									
								
								peer/packets_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								peer/packets_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | |||||||
|  | package peer | ||||||
							
								
								
									
										29
									
								
								peer/state.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								peer/state.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | |||||||
|  | package peer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/netip" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type peerRoute struct { | ||||||
|  | 	IP            byte // VPN IP of peer (last byte). | ||||||
|  | 	Up            bool // True if data can be sent on the route. | ||||||
|  | 	Relay         bool // True if the peer is a relay. | ||||||
|  | 	Direct        bool // True if this is a direct connection. | ||||||
|  | 	PubSignKey    []byte | ||||||
|  | 	ControlCipher *controlCipher | ||||||
|  | 	DataCipher    *dataCipher | ||||||
|  | 	RemoteAddr    netip.AddrPort // Remote address if directly connected. | ||||||
|  |  | ||||||
|  | 	Counter  *uint64   // For sending to. Atomic access only. | ||||||
|  | 	DupCheck *dupCheck // For receiving from. Not safe for concurrent use. | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newPeerRoute(ip byte) *peerRoute { | ||||||
|  | 	counter := uint64(time.Now().Unix()<<30 + 1) | ||||||
|  | 	return &peerRoute{ | ||||||
|  | 		IP:       ip, | ||||||
|  | 		Counter:  &counter, | ||||||
|  | 		DupCheck: newDupCheck(0), | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user