diff --git a/peer/connreader_test.go b/peer/connreader_test.go index 7ef4ad8..5d91547 100644 --- a/peer/connreader_test.go +++ b/peer/connreader_test.go @@ -1,6 +1,13 @@ package peer -/* +import ( + "bytes" + "net/netip" + "reflect" + "sync/atomic" + "testing" +) + type mockIfWriter struct { Written [][]byte } @@ -135,7 +142,7 @@ func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) { pkt := synPacket{TraceID: 1234} - encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) var header header header.Parse(encrypted) header.StreamID = 100 @@ -154,7 +161,8 @@ func TestConnReader_handleControlPacket_noCipher(t *testing.T) { pkt := synPacket{TraceID: 1234} - encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + //encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + encrypted := encryptControlPacket(1, h.Remote, pkt, newBuf(), newBuf()) var header header header.Parse(encrypted) header.SourceIP = 10 @@ -173,7 +181,7 @@ func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) { pkt := synPacket{TraceID: 1234} - encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) var header header header.Parse(encrypted) header.DestIP++ @@ -192,7 +200,7 @@ func TestConnReader_handleControlPacket_modified(t *testing.T) { pkt := synPacket{TraceID: 1234} - encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) encrypted[len(encrypted)-1]++ h.WRemote.writeTo(encrypted, netip.AddrPort{}) @@ -202,19 +210,21 @@ func TestConnReader_handleControlPacket_modified(t *testing.T) { } } -type emptyPacket struct{} +type unknownPacket struct{} -func (p emptyPacket) Marshal(buf []byte) []byte { - return buf[:0] +func (p unknownPacket) Marshal(buf []byte) []byte { + buf = buf[:1] + buf[0] = 100 + return buf } // Testing that an empty control packet is ignored. -func TestConnReader_handleControlPacket_empty(t *testing.T) { +func TestConnReader_handleControlPacket_unknownPacketType(t *testing.T) { h := newConnReadeTestHarness() - pkt := emptyPacket{} + pkt := unknownPacket{} - encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) + encrypted := encryptControlPacket(2, h.Remote, pkt, newBuf(), newBuf()) h.WRemote.writeTo(encrypted, netip.AddrPort{}) h.R.handleNextPacket() if len(h.Super.Messages) != 0 { @@ -228,13 +238,8 @@ func TestConnReader_handleControlPacket_duplicate(t *testing.T) { pkt := synPacket{TraceID: 1234} - log.Printf("%d", h.WRemote.counters[1]) h.WRemote.SendControlPacket(pkt, h.Remote) - log.Printf("%d", h.WRemote.counters[1]) - - // Rewind the counter. - h.WRemote.counters[1] = h.WRemote.counters[1] - 1 - log.Printf("%d", h.WRemote.counters[1]) + *h.Remote.Counter = *h.Remote.Counter - 1 h.WRemote.SendControlPacket(pkt, h.Remote) h.R.handleNextPacket() @@ -250,28 +255,7 @@ func TestConnReader_handleControlPacket_duplicate(t *testing.T) { } } -type invalidPacket struct { -} - -func (p invalidPacket) Marshal(b []byte) []byte { - out := b[:256] - clear(out) - return out -} - -// Testing that an invalid control packet is ignored (fails to parse). -func TestConnReader_handleControlPacket_cantParse(t *testing.T) { - h := newConnReadeTestHarness() - - pkt := invalidPacket{} - - encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote) - h.WRemote.writeTo(encrypted, netip.AddrPort{}) - h.R.handleNextPacket() - if len(h.Super.Messages) != 0 { - t.Fatal(h.Super.Messages) - } -} +/* // Testing that we can receive a data packet. func TestConnReader_handleDataPacket(t *testing.T) { diff --git a/peer/crypto.go b/peer/crypto.go index f41c8bf..3bc970f 100644 --- a/peer/crypto.go +++ b/peer/crypto.go @@ -52,6 +52,8 @@ func encryptControlPacket( } // Returns a controlMsg[PacketType]. Route must have ControlCipher. +// +// This function also drops packets with duplicate sequence numbers. func decryptControlPacket( route *peerRoute, fromAddr netip.AddrPort, diff --git a/peer/globals.go b/peer/globals.go index a4d8d65..4733ac8 100644 --- a/peer/globals.go +++ b/peer/globals.go @@ -17,3 +17,7 @@ const ( var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom( netip.AddrFrom4([4]byte{224, 0, 0, 157}), 4560)) + +func newBuf() []byte { + return make([]byte, bufferSize) +}