package peer import ( "bytes" "crypto/rand" "errors" "net/netip" "reflect" "testing" ) func newRoutePairForTesting() (*remotePeer, *remotePeer) { keys1 := generateKeys() keys2 := generateKeys() r1 := NewRemotePeer(1) r1.PubSignKey = keys1.PubSignKey r1.ControlCipher = newControlCipher(keys1.PrivKey, keys2.PubKey) r1.DataCipher = newDataCipher() r2 := NewRemotePeer(2) r2.PubSignKey = keys2.PubSignKey r2.ControlCipher = newControlCipher(keys2.PrivKey, keys1.PubKey) r2.DataCipher = r1.DataCipher return r1, r2 } func TestDecryptControlPacket(t *testing.T) { var ( r1, r2 = newRoutePairForTesting() tmp = make([]byte, bufferSize) out = make([]byte, bufferSize) ) in := packetSyn{ TraceID: newTraceID(), SharedKey: r1.DataCipher.Key(), Direct: true, } enc := encryptControlPacket(r1.IP, r2, in, tmp, out) h := parseHeader(enc) iMsg, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) if err != nil { t.Fatal(err) } msg, ok := iMsg.(controlMsg[packetSyn]) if !ok { t.Fatal(ok) } if !reflect.DeepEqual(msg.Packet, in) { t.Fatal(msg) } } func TestDecryptControlPacket_decryptionFailed(t *testing.T) { var ( r1, r2 = newRoutePairForTesting() tmp = make([]byte, bufferSize) out = make([]byte, bufferSize) ) in := packetSyn{ TraceID: newTraceID(), SharedKey: r1.DataCipher.Key(), Direct: true, } enc := encryptControlPacket(r1.IP, r2, in, tmp, out) h := parseHeader(enc) for i := range enc { x := bytes.Clone(enc) x[i]++ _, err := decryptControlPacket(r2, netip.AddrPort{}, h, x, tmp) if !errors.Is(err, errDecryptionFailed) { t.Fatal(i, err) } } } func TestDecryptControlPacket_duplicate(t *testing.T) { var ( r1, r2 = newRoutePairForTesting() tmp = make([]byte, bufferSize) out = make([]byte, bufferSize) ) in := packetSyn{ TraceID: newTraceID(), SharedKey: r1.DataCipher.Key(), Direct: true, } enc := encryptControlPacket(r1.IP, r2, in, tmp, out) h := parseHeader(enc) if _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp); err != nil { t.Fatal(err) } _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) if !errors.Is(err, errDuplicateSeqNum) { t.Fatal(err) } } /* func TestDecryptControlPacket_invalidPacket(t *testing.T) { var ( r1, r2 = newRoutePairForTesting() tmp = make([]byte, bufferSize) out = make([]byte, bufferSize) ) in := testPacket("hello!") enc := encryptControlPacket(r1.IP, r2, in, tmp, out) h := parseHeader(enc) _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp) if !errors.Is(err, errUnknownPacketType) { t.Fatal(err) } } */ func TestDecryptDataPacket(t *testing.T) { var ( r1, r2 = newRoutePairForTesting() out = make([]byte, bufferSize) data = make([]byte, 1024) ) rand.Read(data) enc := encryptDataPacket(r1.IP, r2.IP, r2, data, out) h := parseHeader(enc) out, err := decryptDataPacket(r1, h, bytes.Clone(enc), out) if err != nil { t.Fatal(err) } if !bytes.Equal(data, out) { t.Fatal(data, out) } } func TestDecryptDataPacket_incorrectCipher(t *testing.T) { var ( r1, r2 = newRoutePairForTesting() out = make([]byte, bufferSize) data = make([]byte, 1024) ) rand.Read(data) enc := encryptDataPacket(r1.IP, r2.IP, r2, data, bytes.Clone(out)) h := parseHeader(enc) r1.DataCipher = newDataCipher() _, err := decryptDataPacket(r1, h, enc, bytes.Clone(out)) if !errors.Is(err, errDecryptionFailed) { t.Fatal(err) } } func TestDecryptDataPacket_duplicate(t *testing.T) { var ( r1, r2 = newRoutePairForTesting() out = make([]byte, bufferSize) data = make([]byte, 1024) ) rand.Read(data) enc := encryptDataPacket(r1.IP, r2.IP, r2, data, bytes.Clone(out)) h := parseHeader(enc) _, err := decryptDataPacket(r1, h, enc, bytes.Clone(out)) if err != nil { t.Fatal(err) } _, err = decryptDataPacket(r1, h, enc, bytes.Clone(out)) if !errors.Is(err, errDuplicateSeqNum) { t.Fatal(err) } }