wip
This commit is contained in:
parent
d79902a83b
commit
1a6503bbda
@ -2,11 +2,10 @@
|
|||||||
|
|
||||||
## Refactoring for Testability
|
## Refactoring for Testability
|
||||||
|
|
||||||
* [ ] connWriter
|
* [x] connWriter
|
||||||
* [ ] Separate send/relay calls
|
|
||||||
* [x] mcWriter
|
* [x] mcWriter
|
||||||
* [x] ifWriter
|
* [x] ifWriter
|
||||||
* [ ] ifReader
|
* [ ] ifReader (testing)
|
||||||
* [ ] connReader
|
* [ ] connReader
|
||||||
* [ ] mcReader
|
* [ ] mcReader
|
||||||
* [ ] hubPoller
|
* [ ] hubPoller
|
||||||
|
@ -68,18 +68,18 @@ func newConnWriter(conn udpAddrPortWriter, localIP byte) *connWriter {
|
|||||||
|
|
||||||
// Not safe for concurrent use. Should only be called by supervisor.
|
// Not safe for concurrent use. Should only be called by supervisor.
|
||||||
func (w *connWriter) SendControlPacket(pkt marshaller, route *peerRoute) {
|
func (w *connWriter) SendControlPacket(pkt marshaller, route *peerRoute) {
|
||||||
buf := pkt.Marshal(w.cBuf1)
|
buf := w.encryptControlPacket(pkt, route)
|
||||||
h := header{
|
|
||||||
StreamID: controlStreamID,
|
|
||||||
Counter: atomic.AddUint64(&w.counters[route.IP], 1),
|
|
||||||
SourceIP: w.localIP,
|
|
||||||
DestIP: route.IP,
|
|
||||||
}
|
|
||||||
buf = route.ControlCipher.Encrypt(h, buf, w.cBuf2)
|
|
||||||
w.writeTo(buf, route.RemoteAddr)
|
w.writeTo(buf, route.RemoteAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Relay control packet. Routes must not be nil.
|
||||||
func (w *connWriter) RelayControlPacket(pkt marshaller, route, relay *peerRoute) {
|
func (w *connWriter) RelayControlPacket(pkt marshaller, route, relay *peerRoute) {
|
||||||
|
buf := w.encryptControlPacket(pkt, route)
|
||||||
|
w.relayPacket(buf, w.cBuf1, route, relay)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypted packet will occupy cBuf2.
|
||||||
|
func (w *connWriter) encryptControlPacket(pkt marshaller, route *peerRoute) []byte {
|
||||||
buf := pkt.Marshal(w.cBuf1)
|
buf := pkt.Marshal(w.cBuf1)
|
||||||
h := header{
|
h := header{
|
||||||
StreamID: controlStreamID,
|
StreamID: controlStreamID,
|
||||||
@ -87,12 +87,11 @@ func (w *connWriter) RelayControlPacket(pkt marshaller, route, relay *peerRoute)
|
|||||||
SourceIP: w.localIP,
|
SourceIP: w.localIP,
|
||||||
DestIP: route.IP,
|
DestIP: route.IP,
|
||||||
}
|
}
|
||||||
buf = route.ControlCipher.Encrypt(h, buf, w.cBuf2)
|
return route.ControlCipher.Encrypt(h, buf, w.cBuf2)
|
||||||
w.relayPacket(buf, w.cBuf1, route, relay)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Not safe for concurrent use. Should only be called by ifReader.
|
// Not safe for concurrent use. Should only be called by ifReader.
|
||||||
func (w *connWriter) SendDataPacket(pkt []byte, route, relay *peerRoute) {
|
func (w *connWriter) SendDataPacket(pkt []byte, route *peerRoute) {
|
||||||
h := header{
|
h := header{
|
||||||
StreamID: dataStreamID,
|
StreamID: dataStreamID,
|
||||||
Counter: atomic.AddUint64(&w.counters[route.IP], 1),
|
Counter: atomic.AddUint64(&w.counters[route.IP], 1),
|
||||||
@ -101,17 +100,22 @@ func (w *connWriter) SendDataPacket(pkt []byte, route, relay *peerRoute) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
enc := route.DataCipher.Encrypt(h, pkt, w.dBuf1)
|
enc := route.DataCipher.Encrypt(h, pkt, w.dBuf1)
|
||||||
|
|
||||||
if route.Direct {
|
|
||||||
w.writeTo(enc, route.RemoteAddr)
|
w.writeTo(enc, route.RemoteAddr)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Relay a data packet. Routes must not be nil.
|
||||||
|
func (w *connWriter) RelayDataPacket(pkt []byte, route, relay *peerRoute) {
|
||||||
|
h := header{
|
||||||
|
StreamID: dataStreamID,
|
||||||
|
Counter: atomic.AddUint64(&w.counters[route.IP], 1),
|
||||||
|
SourceIP: w.localIP,
|
||||||
|
DestIP: route.IP,
|
||||||
|
}
|
||||||
|
|
||||||
|
enc := route.DataCipher.Encrypt(h, pkt, w.dBuf1)
|
||||||
w.relayPacket(enc, w.dBuf2, route, relay)
|
w.relayPacket(enc, w.dBuf2, route, relay)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: RelayDataPacket
|
|
||||||
|
|
||||||
// Safe for concurrent use. Should only be called by connReader.
|
// Safe for concurrent use. Should only be called by connReader.
|
||||||
//
|
//
|
||||||
// This function will send pkt to the peer directly. This is used when a peer
|
// This function will send pkt to the peer directly. This is used when a peer
|
||||||
@ -122,10 +126,6 @@ func (w *connWriter) SendEncryptedDataPacket(pkt []byte, route *peerRoute) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *connWriter) relayPacket(data, buf []byte, route, relay *peerRoute) {
|
func (w *connWriter) relayPacket(data, buf []byte, route, relay *peerRoute) {
|
||||||
if relay == nil || !relay.Up {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
h := header{
|
h := header{
|
||||||
StreamID: dataStreamID,
|
StreamID: dataStreamID,
|
||||||
Counter: atomic.AddUint64(&w.counters[relay.IP], 1),
|
Counter: atomic.AddUint64(&w.counters[relay.IP], 1),
|
||||||
|
@ -126,7 +126,7 @@ func TestConnWriter_SendControlPacket_direct(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Testing if we can relay a packet via an intermediary.
|
// Testing if we can relay a packet via an intermediary.
|
||||||
func TestConnWriter_SendControlPacket_relay(t *testing.T) {
|
func TestConnWriter_RelayControlPacket_relay(t *testing.T) {
|
||||||
route, rRoute, relay, rRelay := testConnWriter_getTestRoutes()
|
route, rRoute, relay, rRelay := testConnWriter_getTestRoutes()
|
||||||
|
|
||||||
writer := &testUDPAddrPortWriter{}
|
writer := &testUDPAddrPortWriter{}
|
||||||
@ -159,40 +159,6 @@ func TestConnWriter_SendControlPacket_relay(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Testing that a nil relay doesn't cause an issue.
|
|
||||||
func TestConnWriter_SendControlPacket_relay_relayNil(t *testing.T) {
|
|
||||||
route, rRoute, _, _ := testConnWriter_getTestRoutes()
|
|
||||||
|
|
||||||
writer := &testUDPAddrPortWriter{}
|
|
||||||
w := newConnWriter(writer, rRoute.IP)
|
|
||||||
in := testPacket("hello world!")
|
|
||||||
|
|
||||||
w.RelayControlPacket(in, route, nil)
|
|
||||||
|
|
||||||
out := writer.Written()
|
|
||||||
if len(out) != 0 {
|
|
||||||
t.Fatal(out)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Testing that we don't send anything if the relay isn't up.
|
|
||||||
func TestConnWriter_SendControlPacket_relay_relayNotUp(t *testing.T) {
|
|
||||||
route, rRoute, relay, _ := testConnWriter_getTestRoutes()
|
|
||||||
relay.Up = false
|
|
||||||
|
|
||||||
writer := &testUDPAddrPortWriter{}
|
|
||||||
w := newConnWriter(writer, rRoute.IP)
|
|
||||||
in := testPacket("hello world!")
|
|
||||||
|
|
||||||
w.RelayControlPacket(in, route, relay)
|
|
||||||
|
|
||||||
out := writer.Written()
|
|
||||||
if len(out) != 0 {
|
|
||||||
t.Fatal(out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Testing that we can send a data packet directly to a remote route.
|
// Testing that we can send a data packet directly to a remote route.
|
||||||
func TestConnWriter_SendDataPacket_direct(t *testing.T) {
|
func TestConnWriter_SendDataPacket_direct(t *testing.T) {
|
||||||
route, rRoute, _, _ := testConnWriter_getTestRoutes()
|
route, rRoute, _, _ := testConnWriter_getTestRoutes()
|
||||||
@ -202,7 +168,7 @@ func TestConnWriter_SendDataPacket_direct(t *testing.T) {
|
|||||||
w := newConnWriter(writer, rRoute.IP)
|
w := newConnWriter(writer, rRoute.IP)
|
||||||
|
|
||||||
in := []byte("hello world!")
|
in := []byte("hello world!")
|
||||||
w.SendDataPacket(in, route, nil)
|
w.SendDataPacket(in, route)
|
||||||
|
|
||||||
out := writer.Written()
|
out := writer.Written()
|
||||||
if len(out) != 1 {
|
if len(out) != 1 {
|
||||||
@ -224,14 +190,14 @@ func TestConnWriter_SendDataPacket_direct(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Testing that we can relay a data packet via a relay.
|
// Testing that we can relay a data packet via a relay.
|
||||||
func TestConnWriter_SendDataPacket_relay(t *testing.T) {
|
func TestConnWriter_RelayDataPacket_relay(t *testing.T) {
|
||||||
route, rRoute, relay, rRelay := testConnWriter_getTestRoutes()
|
route, rRoute, relay, rRelay := testConnWriter_getTestRoutes()
|
||||||
|
|
||||||
writer := &testUDPAddrPortWriter{}
|
writer := &testUDPAddrPortWriter{}
|
||||||
w := newConnWriter(writer, rRoute.IP)
|
w := newConnWriter(writer, rRoute.IP)
|
||||||
in := []byte("Hello world!")
|
in := []byte("Hello world!")
|
||||||
|
|
||||||
w.SendDataPacket(in, route, relay)
|
w.RelayDataPacket(in, route, relay)
|
||||||
|
|
||||||
out := writer.Written()
|
out := writer.Written()
|
||||||
if len(out) != 1 {
|
if len(out) != 1 {
|
||||||
@ -257,35 +223,26 @@ func TestConnWriter_SendDataPacket_relay(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Testing that we don't attempt to relay if the relay is nil.
|
// Testing that we can send an already encrypted packet.
|
||||||
func TestConnWriter_SendDataPacket_relay_relayNil(t *testing.T) {
|
func TestConnWriter_SendEncryptedDataPacket(t *testing.T) {
|
||||||
route, rRoute, _, _ := testConnWriter_getTestRoutes()
|
route, rRoute, _, _ := testConnWriter_getTestRoutes()
|
||||||
|
|
||||||
writer := &testUDPAddrPortWriter{}
|
writer := &testUDPAddrPortWriter{}
|
||||||
w := newConnWriter(writer, rRoute.IP)
|
w := newConnWriter(writer, rRoute.IP)
|
||||||
in := []byte("Hello world!")
|
in := []byte("Hello world!")
|
||||||
|
|
||||||
w.SendDataPacket(in, route, nil)
|
w.SendEncryptedDataPacket(in, route)
|
||||||
|
|
||||||
out := writer.Written()
|
out := writer.Written()
|
||||||
if len(out) != 0 {
|
if len(out) != 1 {
|
||||||
t.Fatal(out)
|
t.Fatal(out)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if out[0].Addr != route.RemoteAddr {
|
||||||
|
t.Fatal(out[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Testing that we don't attempt to relay if the relay isn't up.
|
if !bytes.Equal(out[0].Data, in) {
|
||||||
func TestConnWriter_SendDataPacket_relay_relayNotUp(t *testing.T) {
|
t.Fatal(out[0])
|
||||||
route, rRoute, relay, _ := testConnWriter_getTestRoutes()
|
|
||||||
relay.Up = false
|
|
||||||
|
|
||||||
writer := &testUDPAddrPortWriter{}
|
|
||||||
w := newConnWriter(writer, rRoute.IP)
|
|
||||||
in := []byte("Hello world!")
|
|
||||||
|
|
||||||
w.SendDataPacket(in, route, relay)
|
|
||||||
|
|
||||||
out := writer.Written()
|
|
||||||
if len(out) != 0 {
|
|
||||||
t.Fatal(out)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -38,14 +38,14 @@ func (dc *dupCheck) IsDup(counter uint64) bool {
|
|||||||
delta := counter - dc.tailCounter
|
delta := counter - dc.tailCounter
|
||||||
|
|
||||||
// Full clear.
|
// Full clear.
|
||||||
if delta >= bitSetSize {
|
if delta >= bitSetSize-1 {
|
||||||
dc.ClearAll()
|
dc.ClearAll()
|
||||||
dc.Set(0)
|
dc.Set(0)
|
||||||
|
|
||||||
dc.tail = 1
|
dc.tail = 1
|
||||||
dc.head = 2
|
dc.head = 2
|
||||||
dc.tailCounter = counter + 1
|
dc.tailCounter = counter + 1
|
||||||
dc.headCounter = dc.tailCounter - bitSetSize
|
dc.headCounter = dc.tailCounter - bitSetSize + 1
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -20,6 +20,18 @@ type header struct {
|
|||||||
Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic.
|
Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseHeader(b []byte) (h header, ok bool) {
|
||||||
|
if len(b) < headerSize {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.Version = b[0]
|
||||||
|
h.StreamID = b[1]
|
||||||
|
h.SourceIP = b[2]
|
||||||
|
h.DestIP = b[3]
|
||||||
|
h.Counter = *(*uint64)(unsafe.Pointer(&b[4]))
|
||||||
|
return h, true
|
||||||
|
}
|
||||||
|
|
||||||
func (h *header) Parse(b []byte) {
|
func (h *header) Parse(b []byte) {
|
||||||
h.Version = b[0]
|
h.Version = b[0]
|
||||||
h.StreamID = b[1]
|
h.StreamID = b[1]
|
||||||
|
@ -57,7 +57,7 @@ func (r *ifReader) sendPacket(pkt []byte, remoteIP byte) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if relay := r.relay.Load(); relay != nil {
|
if relay := r.relay.Load(); relay != nil && relay.Up {
|
||||||
r.relayDataPacket(pkt, route, relay)
|
r.relayDataPacket(pkt, route, relay)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
21
peer/bitset.go
Normal file
21
peer/bitset.go
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
const bitSetSize = 512 // Multiple of 64.
|
||||||
|
|
||||||
|
type bitSet [bitSetSize / 64]uint64
|
||||||
|
|
||||||
|
func (bs *bitSet) Set(i int) {
|
||||||
|
bs[i/64] |= 1 << (i % 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bs *bitSet) Clear(i int) {
|
||||||
|
bs[i/64] &= ^(1 << (i % 64))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bs *bitSet) ClearAll() {
|
||||||
|
clear(bs[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bs *bitSet) Get(i int) bool {
|
||||||
|
return bs[i/64]&(1<<(i%64)) != 0
|
||||||
|
}
|
48
peer/bitset_test.go
Normal file
48
peer/bitset_test.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBitSet(t *testing.T) {
|
||||||
|
state := make([]bool, bitSetSize)
|
||||||
|
for i := range state {
|
||||||
|
state[i] = rand.Float32() > 0.5
|
||||||
|
}
|
||||||
|
|
||||||
|
bs := bitSet{}
|
||||||
|
|
||||||
|
for i := range state {
|
||||||
|
if state[i] {
|
||||||
|
bs.Set(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range state {
|
||||||
|
if bs.Get(i) != state[i] {
|
||||||
|
t.Fatal(i, state[i], bs.Get(i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range state {
|
||||||
|
if rand.Float32() > 0.5 {
|
||||||
|
state[i] = false
|
||||||
|
bs.Clear(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range state {
|
||||||
|
if bs.Get(i) != state[i] {
|
||||||
|
t.Fatal(i, state[i], bs.Get(i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bs.ClearAll()
|
||||||
|
|
||||||
|
for i := range state {
|
||||||
|
if bs.Get(i) {
|
||||||
|
t.Fatal(i, bs.Get(i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
26
peer/cipher-control.go
Normal file
26
peer/cipher-control.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import "golang.org/x/crypto/nacl/box"
|
||||||
|
|
||||||
|
type controlCipher struct {
|
||||||
|
sharedKey [32]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newControlCipher(privKey, pubKey []byte) *controlCipher {
|
||||||
|
shared := [32]byte{}
|
||||||
|
box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey))
|
||||||
|
return &controlCipher{shared}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *controlCipher) Encrypt(h header, data, out []byte) []byte {
|
||||||
|
const s = controlHeaderSize
|
||||||
|
out = out[:s+controlCipherOverhead+len(data)]
|
||||||
|
h.Marshal(out[:s])
|
||||||
|
box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &cc.sharedKey)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *controlCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) {
|
||||||
|
const s = controlHeaderSize
|
||||||
|
return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &cc.sharedKey)
|
||||||
|
}
|
122
peer/cipher-control_test.go
Normal file
122
peer/cipher-control_test.go
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/nacl/box"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newControlCipherForTesting() (c1, c2 *controlCipher) {
|
||||||
|
pubKey1, privKey1, err := box.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pubKey2, privKey2, err := box.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return newControlCipher(privKey1[:], pubKey2[:]),
|
||||||
|
newControlCipher(privKey2[:], pubKey1[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestControlCipher(t *testing.T) {
|
||||||
|
c1, c2 := newControlCipherForTesting()
|
||||||
|
|
||||||
|
maxSizePlaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
|
||||||
|
rand.Read(maxSizePlaintext)
|
||||||
|
|
||||||
|
testCases := [][]byte{
|
||||||
|
make([]byte, 0),
|
||||||
|
{1},
|
||||||
|
{255},
|
||||||
|
{1, 2, 3, 4, 5},
|
||||||
|
[]byte("Hello world"),
|
||||||
|
maxSizePlaintext,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, plaintext := range testCases {
|
||||||
|
h1 := header{
|
||||||
|
StreamID: controlStreamID,
|
||||||
|
Counter: 235153,
|
||||||
|
SourceIP: 4,
|
||||||
|
DestIP: 88,
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypted := make([]byte, bufferSize)
|
||||||
|
|
||||||
|
encrypted = c1.Encrypt(h1, plaintext, encrypted)
|
||||||
|
|
||||||
|
h2 := header{}
|
||||||
|
h2.Parse(encrypted)
|
||||||
|
if !reflect.DeepEqual(h1, h2) {
|
||||||
|
t.Fatal(h1, h2)
|
||||||
|
}
|
||||||
|
|
||||||
|
decrypted, ok := c2.Decrypt(encrypted, make([]byte, bufferSize))
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(decrypted, plaintext) {
|
||||||
|
t.Fatal("not equal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestControlCipher_ShortCiphertext(t *testing.T) {
|
||||||
|
c1, _ := newControlCipherForTesting()
|
||||||
|
shortText := make([]byte, controlHeaderSize+controlCipherOverhead-1)
|
||||||
|
rand.Read(shortText)
|
||||||
|
_, ok := c1.Decrypt(shortText, make([]byte, bufferSize))
|
||||||
|
if ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkControlCipher_Encrypt(b *testing.B) {
|
||||||
|
c1, _ := newControlCipherForTesting()
|
||||||
|
h1 := header{
|
||||||
|
Counter: 235153,
|
||||||
|
SourceIP: 4,
|
||||||
|
DestIP: 88,
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
|
||||||
|
rand.Read(plaintext)
|
||||||
|
|
||||||
|
encrypted := make([]byte, bufferSize)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
encrypted = c1.Encrypt(h1, plaintext, encrypted)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkControlCipher_Decrypt(b *testing.B) {
|
||||||
|
c1, c2 := newControlCipherForTesting()
|
||||||
|
|
||||||
|
h1 := header{
|
||||||
|
Counter: 235153,
|
||||||
|
SourceIP: 4,
|
||||||
|
DestIP: 88,
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
|
||||||
|
rand.Read(plaintext)
|
||||||
|
|
||||||
|
encrypted := make([]byte, bufferSize)
|
||||||
|
|
||||||
|
encrypted = c1.Encrypt(h1, plaintext, encrypted)
|
||||||
|
|
||||||
|
decrypted := make([]byte, bufferSize)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
decrypted, _ = c2.Decrypt(encrypted, decrypted)
|
||||||
|
}
|
||||||
|
}
|
61
peer/cipher-data.go
Normal file
61
peer/cipher-data.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dataCipher struct {
|
||||||
|
key [32]byte
|
||||||
|
aead cipher.AEAD
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDataCipher() *dataCipher {
|
||||||
|
key := [32]byte{}
|
||||||
|
if _, err := rand.Read(key[:]); err != nil {
|
||||||
|
log.Fatalf("Failed to read random data: %v", err)
|
||||||
|
}
|
||||||
|
return newDataCipherFromKey(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDataCipherFromKey(key [32]byte) *dataCipher {
|
||||||
|
block, err := aes.NewCipher(key[:])
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to create new cipher: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
aead, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to create new GCM: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &dataCipher{key: key, aead: aead}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *dataCipher) Key() [32]byte {
|
||||||
|
return sc.key
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *dataCipher) Encrypt(h header, data, out []byte) []byte {
|
||||||
|
const s = dataHeaderSize
|
||||||
|
out = out[:s+dataCipherOverhead+len(data)]
|
||||||
|
h.Marshal(out[:s])
|
||||||
|
sc.aead.Seal(out[s:s], out[:s], data, nil)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) {
|
||||||
|
const s = dataHeaderSize
|
||||||
|
if len(encrypted) < s+dataCipherOverhead {
|
||||||
|
ok = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
|
data, err = sc.aead.Open(out[:0], encrypted[:s], encrypted[s:], nil)
|
||||||
|
ok = err == nil
|
||||||
|
return
|
||||||
|
}
|
141
peer/cipher-data_test.go
Normal file
141
peer/cipher-data_test.go
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
mrand "math/rand/v2"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDataCipher(t *testing.T) {
|
||||||
|
maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
|
||||||
|
rand.Read(maxSizePlaintext)
|
||||||
|
|
||||||
|
testCases := [][]byte{
|
||||||
|
make([]byte, 0),
|
||||||
|
{1},
|
||||||
|
{255},
|
||||||
|
{1, 2, 3, 4, 5},
|
||||||
|
[]byte("Hello world"),
|
||||||
|
maxSizePlaintext,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, plaintext := range testCases {
|
||||||
|
h1 := header{
|
||||||
|
StreamID: dataStreamID,
|
||||||
|
Counter: 235153,
|
||||||
|
SourceIP: 4,
|
||||||
|
DestIP: 88,
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypted := make([]byte, bufferSize)
|
||||||
|
|
||||||
|
dc1 := newDataCipher()
|
||||||
|
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
|
||||||
|
h2 := header{}
|
||||||
|
h2.Parse(encrypted)
|
||||||
|
|
||||||
|
dc2 := newDataCipherFromKey(dc1.Key())
|
||||||
|
|
||||||
|
decrypted, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(plaintext, decrypted) {
|
||||||
|
t.Fatal("not equal")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(h1, h2) {
|
||||||
|
t.Fatalf("%v != %v", h1, h2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDataCipher_ModifyCiphertext(t *testing.T) {
|
||||||
|
maxSizePlaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
|
||||||
|
rand.Read(maxSizePlaintext)
|
||||||
|
|
||||||
|
testCases := [][]byte{
|
||||||
|
make([]byte, 0),
|
||||||
|
{1},
|
||||||
|
{255},
|
||||||
|
{1, 2, 3, 4, 5},
|
||||||
|
[]byte("Hello world"),
|
||||||
|
maxSizePlaintext,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, plaintext := range testCases {
|
||||||
|
h1 := header{
|
||||||
|
Counter: 235153,
|
||||||
|
SourceIP: 4,
|
||||||
|
DestIP: 88,
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypted := make([]byte, bufferSize)
|
||||||
|
|
||||||
|
dc1 := newDataCipher()
|
||||||
|
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
|
||||||
|
encrypted[mrand.IntN(len(encrypted))]++
|
||||||
|
|
||||||
|
dc2 := newDataCipherFromKey(dc1.Key())
|
||||||
|
|
||||||
|
_, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
|
||||||
|
if ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDataCipher_ShortCiphertext(t *testing.T) {
|
||||||
|
dc1 := newDataCipher()
|
||||||
|
shortText := make([]byte, dataHeaderSize+dataCipherOverhead-1)
|
||||||
|
rand.Read(shortText)
|
||||||
|
_, ok := dc1.Decrypt(shortText, make([]byte, bufferSize))
|
||||||
|
if ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDataCipher_Encrypt(b *testing.B) {
|
||||||
|
h1 := header{
|
||||||
|
Counter: 235153,
|
||||||
|
SourceIP: 4,
|
||||||
|
DestIP: 88,
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
|
||||||
|
rand.Read(plaintext)
|
||||||
|
|
||||||
|
encrypted := make([]byte, bufferSize)
|
||||||
|
|
||||||
|
dc1 := newDataCipher()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDataCipher_Decrypt(b *testing.B) {
|
||||||
|
h1 := header{
|
||||||
|
Counter: 235153,
|
||||||
|
SourceIP: 4,
|
||||||
|
DestIP: 88,
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext := make([]byte, bufferSize-dataHeaderSize-dataCipherOverhead)
|
||||||
|
rand.Read(plaintext)
|
||||||
|
|
||||||
|
encrypted := make([]byte, bufferSize)
|
||||||
|
|
||||||
|
dc1 := newDataCipher()
|
||||||
|
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
|
||||||
|
|
||||||
|
decrypted := make([]byte, bufferSize)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
decrypted, _ = dc1.Decrypt(encrypted, decrypted)
|
||||||
|
}
|
||||||
|
}
|
13
peer/cipher-discovery.go
Normal file
13
peer/cipher-discovery.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
/*
|
||||||
|
func signData(privKey *[64]byte, h header, data, out []byte) []byte {
|
||||||
|
out = out[:headerSize]
|
||||||
|
h.Marshal(out)
|
||||||
|
return sign.Sign(out, data, privKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func openData(pubKey *[32]byte, signed, out []byte) (data []byte, ok bool) {
|
||||||
|
return sign.Open(out[:0], signed[headerSize:], pubKey)
|
||||||
|
}
|
||||||
|
*/
|
141
peer/connreader.go
Normal file
141
peer/connreader.go
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/netip"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type connReader struct {
|
||||||
|
conn udpReader
|
||||||
|
iface ifWriter
|
||||||
|
sender encryptedPacketSender
|
||||||
|
super controlMsgHandler
|
||||||
|
localIP byte
|
||||||
|
routes [256]*atomic.Pointer[peerRoute]
|
||||||
|
|
||||||
|
buf []byte
|
||||||
|
decBuf []byte
|
||||||
|
dupChecks [256]*dupCheck
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConnReader(
|
||||||
|
conn udpReader,
|
||||||
|
ifWriter ifWriter,
|
||||||
|
sender encryptedPacketSender,
|
||||||
|
super controlMsgHandler,
|
||||||
|
localIP byte,
|
||||||
|
routes [256]*atomic.Pointer[peerRoute],
|
||||||
|
) *connReader {
|
||||||
|
return &connReader{
|
||||||
|
conn: conn,
|
||||||
|
iface: ifWriter,
|
||||||
|
sender: sender,
|
||||||
|
super: super,
|
||||||
|
localIP: localIP,
|
||||||
|
routes: routes,
|
||||||
|
buf: make([]byte, bufferSize),
|
||||||
|
decBuf: make([]byte, bufferSize),
|
||||||
|
dupChecks: func() (out [256]*dupCheck) {
|
||||||
|
for i := range out {
|
||||||
|
out[i] = newDupCheck(0)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *connReader) Run() {
|
||||||
|
for {
|
||||||
|
r.handleNextPacket()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *connReader) logf(s string, args ...any) {
|
||||||
|
log.Printf("[ConnReader] "+s, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *connReader) handleNextPacket() {
|
||||||
|
buf := r.buf[:bufferSize]
|
||||||
|
n, remoteAddr, err := r.conn.ReadFromUDPAddrPort(buf)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to read from UDP port: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n < headerSize {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port())
|
||||||
|
|
||||||
|
buf = buf[:n]
|
||||||
|
h, ok := parseHeader(buf)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
route := r.routes[h.SourceIP].Load()
|
||||||
|
|
||||||
|
switch h.StreamID {
|
||||||
|
case controlStreamID:
|
||||||
|
r.handleControlPacket(route, remoteAddr, h, buf)
|
||||||
|
|
||||||
|
case dataStreamID:
|
||||||
|
r.handleDataPacket(route, h, buf)
|
||||||
|
|
||||||
|
default:
|
||||||
|
r.logf("Unknown stream ID: %d", h.StreamID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *connReader) handleControlPacket(
|
||||||
|
route *peerRoute,
|
||||||
|
addr netip.AddrPort,
|
||||||
|
h header,
|
||||||
|
enc []byte,
|
||||||
|
) {
|
||||||
|
if route.ControlCipher == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.DestIP != r.localIP {
|
||||||
|
r.logf("Incorrect destination IP on control packet: %d", h.DestIP)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := decryptControlPacket(route, addr, h, enc, r.decBuf)
|
||||||
|
if err != nil {
|
||||||
|
r.logf("Failed to decrypt control packet: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.super.HandleControlMsg(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *connReader) handleDataPacket(route *peerRoute, h header, enc []byte) {
|
||||||
|
if !route.Up {
|
||||||
|
r.logf("Not connected (recv).")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := decryptDataPacket(route, h, enc, r.decBuf)
|
||||||
|
if err != nil {
|
||||||
|
r.logf("Failed to decrypt data packet: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.DestIP == r.localIP {
|
||||||
|
if _, err := r.iface.Write(data); err != nil {
|
||||||
|
log.Fatalf("Failed to write to interface: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
destRoute := r.routes[h.DestIP].Load()
|
||||||
|
if !destRoute.Up {
|
||||||
|
r.logf("Not connected (relay): %d", destRoute.IP)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.sender.SendEncryptedDataPacket(data, destRoute)
|
||||||
|
}
|
318
peer/connreader_test.go
Normal file
318
peer/connreader_test.go
Normal file
@ -0,0 +1,318 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
/*
|
||||||
|
type mockIfWriter struct {
|
||||||
|
Written [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mockIfWriter) Write(b []byte) (int, error) {
|
||||||
|
w.Written = append(w.Written, bytes.Clone(b))
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockEncryptedPacket struct {
|
||||||
|
Packet []byte
|
||||||
|
Route *peerRoute
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockEncryptedPacketSender struct {
|
||||||
|
Sent []mockEncryptedPacket
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockEncryptedPacketSender) SendEncryptedDataPacket(pkt []byte, route *peerRoute) {
|
||||||
|
m.Sent = append(m.Sent, mockEncryptedPacket{
|
||||||
|
Packet: bytes.Clone(pkt),
|
||||||
|
Route: route,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockControlMsgHandler struct {
|
||||||
|
Messages []any
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockControlMsgHandler) HandleControlMsg(pkt any) {
|
||||||
|
m.Messages = append(m.Messages, pkt)
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpPipe struct {
|
||||||
|
packets chan []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUDPPipe() *udpPipe {
|
||||||
|
return &udpPipe{make(chan []byte, 1024)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *udpPipe) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
|
||||||
|
p.packets <- bytes.Clone(b)
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *udpPipe) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) {
|
||||||
|
packet := <-p.packets
|
||||||
|
copy(b, packet)
|
||||||
|
return len(packet), netip.AddrPort{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type connReaderTestHarness struct {
|
||||||
|
Pipe *udpPipe
|
||||||
|
R *connReader
|
||||||
|
WRemote *connWriter
|
||||||
|
WRelayRemote *connWriter
|
||||||
|
Remote *peerRoute
|
||||||
|
RelayRemote *peerRoute
|
||||||
|
IFace *mockIfWriter
|
||||||
|
Sender *mockEncryptedPacketSender
|
||||||
|
Super *mockControlMsgHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peer 2 is indirect, peer 3 is direct.
|
||||||
|
func newConnReadeTestHarness() (h connReaderTestHarness) {
|
||||||
|
pipe := newUDPPipe()
|
||||||
|
routes := [256]*atomic.Pointer[peerRoute]{}
|
||||||
|
for i := range routes {
|
||||||
|
routes[i] = &atomic.Pointer[peerRoute]{}
|
||||||
|
routes[i].Store(&peerRoute{})
|
||||||
|
}
|
||||||
|
|
||||||
|
local, remote, relayLocal, relayRemote := testConnWriter_getTestRoutes()
|
||||||
|
routes[2].Store(local)
|
||||||
|
routes[3].Store(relayLocal)
|
||||||
|
|
||||||
|
h.Pipe = pipe
|
||||||
|
h.WRemote = newConnWriter(pipe, 2)
|
||||||
|
h.WRelayRemote = newConnWriter(pipe, 3)
|
||||||
|
|
||||||
|
h.Remote = remote
|
||||||
|
h.RelayRemote = relayRemote
|
||||||
|
h.IFace = &mockIfWriter{}
|
||||||
|
h.Sender = &mockEncryptedPacketSender{}
|
||||||
|
h.Super = &mockControlMsgHandler{}
|
||||||
|
h.R = newConnReader(
|
||||||
|
pipe,
|
||||||
|
h.IFace,
|
||||||
|
h.Sender,
|
||||||
|
h.Super,
|
||||||
|
1,
|
||||||
|
routes)
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that we can receive a control packet.
|
||||||
|
func TestConnReader_handleControlPacket(t *testing.T) {
|
||||||
|
h := newConnReadeTestHarness()
|
||||||
|
|
||||||
|
pkt := synPacket{TraceID: 1234}
|
||||||
|
|
||||||
|
h.WRemote.SendControlPacket(pkt, h.Remote)
|
||||||
|
|
||||||
|
h.R.handleNextPacket()
|
||||||
|
|
||||||
|
if len(h.Super.Messages) != 1 {
|
||||||
|
t.Fatal(h.Super.Messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := h.Super.Messages[0].(controlMsg[synPacket])
|
||||||
|
if !reflect.DeepEqual(pkt, msg.Packet) {
|
||||||
|
t.Fatal(msg.Packet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that a short packet is ignored.
|
||||||
|
func TestConnReader_handleNextPacket_short(t *testing.T) {
|
||||||
|
h := newConnReadeTestHarness()
|
||||||
|
|
||||||
|
h.Pipe.WriteToUDPAddrPort([]byte{1, 2, 3}, netip.AddrPort{})
|
||||||
|
h.R.handleNextPacket()
|
||||||
|
|
||||||
|
if len(h.Super.Messages) != 0 {
|
||||||
|
t.Fatal(h.Super.Messages)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that a packet with an unexpected stream ID is ignored.
|
||||||
|
func TestConnReader_handleNextPacket_unknownStreamID(t *testing.T) {
|
||||||
|
h := newConnReadeTestHarness()
|
||||||
|
|
||||||
|
pkt := synPacket{TraceID: 1234}
|
||||||
|
|
||||||
|
encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote)
|
||||||
|
var header header
|
||||||
|
header.Parse(encrypted)
|
||||||
|
header.StreamID = 100
|
||||||
|
header.Marshal(encrypted)
|
||||||
|
|
||||||
|
h.WRemote.writeTo(encrypted, netip.AddrPort{})
|
||||||
|
h.R.handleNextPacket()
|
||||||
|
if len(h.Super.Messages) != 0 {
|
||||||
|
t.Fatal(h.Super.Messages)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that control packet without matching control cipher is ignored.
|
||||||
|
func TestConnReader_handleControlPacket_noCipher(t *testing.T) {
|
||||||
|
h := newConnReadeTestHarness()
|
||||||
|
|
||||||
|
pkt := synPacket{TraceID: 1234}
|
||||||
|
|
||||||
|
encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote)
|
||||||
|
var header header
|
||||||
|
header.Parse(encrypted)
|
||||||
|
header.SourceIP = 10
|
||||||
|
header.Marshal(encrypted)
|
||||||
|
|
||||||
|
h.WRemote.writeTo(encrypted, netip.AddrPort{})
|
||||||
|
h.R.handleNextPacket()
|
||||||
|
if len(h.Super.Messages) != 0 {
|
||||||
|
t.Fatal(h.Super.Messages)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that control packet with incrrect destination IP is ignored.
|
||||||
|
func TestConnReader_handleControlPacket_incorrectDest(t *testing.T) {
|
||||||
|
h := newConnReadeTestHarness()
|
||||||
|
|
||||||
|
pkt := synPacket{TraceID: 1234}
|
||||||
|
|
||||||
|
encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote)
|
||||||
|
var header header
|
||||||
|
header.Parse(encrypted)
|
||||||
|
header.DestIP++
|
||||||
|
header.Marshal(encrypted)
|
||||||
|
|
||||||
|
h.WRemote.writeTo(encrypted, netip.AddrPort{})
|
||||||
|
h.R.handleNextPacket()
|
||||||
|
if len(h.Super.Messages) != 0 {
|
||||||
|
t.Fatal(h.Super.Messages)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that modified control packet is ignored.
|
||||||
|
func TestConnReader_handleControlPacket_modified(t *testing.T) {
|
||||||
|
h := newConnReadeTestHarness()
|
||||||
|
|
||||||
|
pkt := synPacket{TraceID: 1234}
|
||||||
|
|
||||||
|
encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote)
|
||||||
|
encrypted[len(encrypted)-1]++
|
||||||
|
|
||||||
|
h.WRemote.writeTo(encrypted, netip.AddrPort{})
|
||||||
|
h.R.handleNextPacket()
|
||||||
|
if len(h.Super.Messages) != 0 {
|
||||||
|
t.Fatal(h.Super.Messages)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type emptyPacket struct{}
|
||||||
|
|
||||||
|
func (p emptyPacket) Marshal(buf []byte) []byte {
|
||||||
|
return buf[:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that an empty control packet is ignored.
|
||||||
|
func TestConnReader_handleControlPacket_empty(t *testing.T) {
|
||||||
|
h := newConnReadeTestHarness()
|
||||||
|
|
||||||
|
pkt := emptyPacket{}
|
||||||
|
|
||||||
|
encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote)
|
||||||
|
h.WRemote.writeTo(encrypted, netip.AddrPort{})
|
||||||
|
h.R.handleNextPacket()
|
||||||
|
if len(h.Super.Messages) != 0 {
|
||||||
|
t.Fatal(h.Super.Messages)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that a duplicate control packet is ignored.
|
||||||
|
func TestConnReader_handleControlPacket_duplicate(t *testing.T) {
|
||||||
|
h := newConnReadeTestHarness()
|
||||||
|
|
||||||
|
pkt := synPacket{TraceID: 1234}
|
||||||
|
|
||||||
|
log.Printf("%d", h.WRemote.counters[1])
|
||||||
|
h.WRemote.SendControlPacket(pkt, h.Remote)
|
||||||
|
log.Printf("%d", h.WRemote.counters[1])
|
||||||
|
|
||||||
|
// Rewind the counter.
|
||||||
|
h.WRemote.counters[1] = h.WRemote.counters[1] - 1
|
||||||
|
log.Printf("%d", h.WRemote.counters[1])
|
||||||
|
h.WRemote.SendControlPacket(pkt, h.Remote)
|
||||||
|
|
||||||
|
h.R.handleNextPacket()
|
||||||
|
h.R.handleNextPacket()
|
||||||
|
|
||||||
|
if len(h.Super.Messages) != 1 {
|
||||||
|
t.Fatal(h.Super.Messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := h.Super.Messages[0].(controlMsg[synPacket])
|
||||||
|
if !reflect.DeepEqual(pkt, msg.Packet) {
|
||||||
|
t.Fatal(msg.Packet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type invalidPacket struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p invalidPacket) Marshal(b []byte) []byte {
|
||||||
|
out := b[:256]
|
||||||
|
clear(out)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that an invalid control packet is ignored (fails to parse).
|
||||||
|
func TestConnReader_handleControlPacket_cantParse(t *testing.T) {
|
||||||
|
h := newConnReadeTestHarness()
|
||||||
|
|
||||||
|
pkt := invalidPacket{}
|
||||||
|
|
||||||
|
encrypted := h.WRemote.encryptControlPacket(pkt, h.Remote)
|
||||||
|
h.WRemote.writeTo(encrypted, netip.AddrPort{})
|
||||||
|
h.R.handleNextPacket()
|
||||||
|
if len(h.Super.Messages) != 0 {
|
||||||
|
t.Fatal(h.Super.Messages)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that we can receive a data packet.
|
||||||
|
func TestConnReader_handleDataPacket(t *testing.T) {
|
||||||
|
h := newConnReadeTestHarness()
|
||||||
|
|
||||||
|
pkt := make([]byte, 1024)
|
||||||
|
rand.Read(pkt)
|
||||||
|
|
||||||
|
h.WRemote.SendDataPacket(pkt, h.Remote)
|
||||||
|
|
||||||
|
h.R.handleNextPacket()
|
||||||
|
|
||||||
|
if len(h.IFace.Written) != 1 {
|
||||||
|
t.Fatal(h.IFace.Written)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(pkt, h.IFace.Written[0]) {
|
||||||
|
t.Fatal(h.IFace.Written)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that data packet is ignored if route isn't up.
|
||||||
|
func TestConnReader_handleDataPacket_routeDown(t *testing.T) {
|
||||||
|
h := newConnReadeTestHarness()
|
||||||
|
|
||||||
|
pkt := make([]byte, 1024)
|
||||||
|
rand.Read(pkt)
|
||||||
|
|
||||||
|
h.WRemote.SendDataPacket(pkt, h.Remote)
|
||||||
|
route := h.R.routes[2].Load()
|
||||||
|
route.Up = false
|
||||||
|
|
||||||
|
h.R.handleNextPacket()
|
||||||
|
|
||||||
|
if len(h.IFace.Written) != 0 {
|
||||||
|
t.Fatal(h.IFace.Written)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
// Testing that a duplicate data packet is ignored.
|
||||||
|
|
||||||
|
// Testing that we send a relayed data packet.
|
||||||
|
|
||||||
|
// Testing that a relayed data packet is ignored if destination isn't up.
|
80
peer/connwriter.go
Normal file
80
peer/connwriter.go
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type connWriter struct {
|
||||||
|
localIP byte
|
||||||
|
conn udpWriter
|
||||||
|
|
||||||
|
// For sending control packets.
|
||||||
|
cBuf1 []byte
|
||||||
|
cBuf2 []byte
|
||||||
|
|
||||||
|
// For sending data packets.
|
||||||
|
dBuf1 []byte
|
||||||
|
dBuf2 []byte
|
||||||
|
|
||||||
|
// Lock around for sending on UDP Conn.
|
||||||
|
wLock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConnWriter(conn udpWriter, localIP byte) *connWriter {
|
||||||
|
w := &connWriter{
|
||||||
|
localIP: localIP,
|
||||||
|
conn: conn,
|
||||||
|
cBuf1: make([]byte, bufferSize),
|
||||||
|
cBuf2: make([]byte, bufferSize),
|
||||||
|
dBuf1: make([]byte, bufferSize),
|
||||||
|
dBuf2: make([]byte, bufferSize),
|
||||||
|
}
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not safe for concurrent use. Should only be called by supervisor.
|
||||||
|
func (w *connWriter) SendControlPacket(pkt marshaller, route *peerRoute) {
|
||||||
|
enc := encryptControlPacket(w.localIP, route, pkt, w.cBuf1, w.cBuf2)
|
||||||
|
w.writeTo(enc, route.RemoteAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Relay control packet. Route must not be nil.
|
||||||
|
func (w *connWriter) RelayControlPacket(pkt marshaller, route, relay *peerRoute) {
|
||||||
|
enc := encryptControlPacket(w.localIP, route, pkt, w.cBuf1, w.cBuf2)
|
||||||
|
enc = encryptDataPacket(w.localIP, route.IP, relay, enc, w.cBuf1)
|
||||||
|
w.writeTo(enc, relay.RemoteAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not safe for concurrent use. Should only be called by ifReader.
|
||||||
|
func (w *connWriter) SendDataPacket(pkt []byte, route *peerRoute) {
|
||||||
|
enc := encryptDataPacket(w.localIP, route.IP, route, pkt, w.dBuf1)
|
||||||
|
w.writeTo(enc, route.RemoteAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Relay a data packet. Route must not be nil.
|
||||||
|
func (w *connWriter) RelayDataPacket(pkt []byte, route, relay *peerRoute) {
|
||||||
|
enc := encryptDataPacket(w.localIP, route.IP, route, pkt, w.dBuf1)
|
||||||
|
enc = encryptDataPacket(w.localIP, route.IP, relay, enc, w.dBuf2)
|
||||||
|
w.writeTo(enc, relay.RemoteAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Safe for concurrent use. Should only be called by connReader.
|
||||||
|
//
|
||||||
|
// This function will send pkt to the peer directly. This is used when a peer
|
||||||
|
// is acting as a relay and is forwarding already encrypted data for another
|
||||||
|
// peer.
|
||||||
|
func (w *connWriter) SendEncryptedDataPacket(pkt []byte, route *peerRoute) {
|
||||||
|
w.writeTo(pkt, route.RemoteAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *connWriter) writeTo(packet []byte, addr netip.AddrPort) {
|
||||||
|
w.wLock.Lock()
|
||||||
|
if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil {
|
||||||
|
log.Printf("[ConnWriter] Failed to write to UDP port: %v", err)
|
||||||
|
}
|
||||||
|
w.wLock.Unlock()
|
||||||
|
}
|
240
peer/connwriter_test.go
Normal file
240
peer/connwriter_test.go
Normal file
@ -0,0 +1,240 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type testUDPPacket struct {
|
||||||
|
Addr netip.AddrPort
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type testUDPAddrPortWriter struct {
|
||||||
|
written []testUDPPacket
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testUDPAddrPortWriter) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
|
||||||
|
w.written = append(w.written, testUDPPacket{
|
||||||
|
Addr: addr,
|
||||||
|
Data: bytes.Clone(b),
|
||||||
|
})
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testUDPAddrPortWriter) Written() []testUDPPacket {
|
||||||
|
out := w.written
|
||||||
|
w.written = []testUDPPacket{}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type testPacket string
|
||||||
|
|
||||||
|
func (p testPacket) Marshal(b []byte) []byte {
|
||||||
|
b = b[:len(p)]
|
||||||
|
copy(b, []byte(p))
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func testConnWriter_getTestRoutes() (local, remote, relayLocal, relayRemote *peerRoute) {
|
||||||
|
localKeys := generateKeys()
|
||||||
|
remoteKeys := generateKeys()
|
||||||
|
|
||||||
|
local = newPeerRoute(2)
|
||||||
|
local.Up = true
|
||||||
|
local.Relay = false
|
||||||
|
local.PubSignKey = remoteKeys.PubSignKey
|
||||||
|
local.ControlCipher = newControlCipher(localKeys.PrivKey, remoteKeys.PubKey)
|
||||||
|
local.DataCipher = newDataCipher()
|
||||||
|
local.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 2}), 100)
|
||||||
|
|
||||||
|
remote = newPeerRoute(1)
|
||||||
|
remote.Up = true
|
||||||
|
remote.Relay = false
|
||||||
|
remote.PubSignKey = localKeys.PubSignKey
|
||||||
|
remote.ControlCipher = newControlCipher(remoteKeys.PrivKey, localKeys.PubKey)
|
||||||
|
remote.DataCipher = local.DataCipher
|
||||||
|
remote.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100)
|
||||||
|
|
||||||
|
rLocalKeys := generateKeys()
|
||||||
|
rRemoteKeys := generateKeys()
|
||||||
|
|
||||||
|
relayLocal = newPeerRoute(3)
|
||||||
|
relayLocal.Up = true
|
||||||
|
relayLocal.Relay = true
|
||||||
|
relayLocal.Direct = true
|
||||||
|
relayLocal.PubSignKey = rRemoteKeys.PubSignKey
|
||||||
|
relayLocal.ControlCipher = newControlCipher(rLocalKeys.PrivKey, rRemoteKeys.PubKey)
|
||||||
|
relayLocal.DataCipher = newDataCipher()
|
||||||
|
relayLocal.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 3}), 100)
|
||||||
|
|
||||||
|
relayRemote = newPeerRoute(1)
|
||||||
|
relayRemote.Up = true
|
||||||
|
relayRemote.Relay = false
|
||||||
|
relayRemote.Direct = true
|
||||||
|
relayRemote.PubSignKey = rLocalKeys.PubSignKey
|
||||||
|
relayRemote.ControlCipher = newControlCipher(rRemoteKeys.PrivKey, rLocalKeys.PubKey)
|
||||||
|
relayRemote.DataCipher = relayLocal.DataCipher
|
||||||
|
relayRemote.RemoteAddr = netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 100)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Testing if we can send a control packet directly to the remote route.
|
||||||
|
func TestConnWriter_SendControlPacket_direct(t *testing.T) {
|
||||||
|
route, rRoute, _, _ := testConnWriter_getTestRoutes()
|
||||||
|
route.Direct = true
|
||||||
|
|
||||||
|
writer := &testUDPAddrPortWriter{}
|
||||||
|
w := newConnWriter(writer, rRoute.IP)
|
||||||
|
in := testPacket("hello world!")
|
||||||
|
|
||||||
|
w.SendControlPacket(in, route)
|
||||||
|
out := writer.Written()
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatal(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
if out[0].Addr != route.RemoteAddr {
|
||||||
|
t.Fatal(out[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
dec, ok := rRoute.ControlCipher.Decrypt(out[0].Data, make([]byte, 1024))
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
if string(dec) != string(in) {
|
||||||
|
t.Fatal(dec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing if we can relay a packet via an intermediary.
|
||||||
|
func TestConnWriter_RelayControlPacket_relay(t *testing.T) {
|
||||||
|
route, rRoute, relay, rRelay := testConnWriter_getTestRoutes()
|
||||||
|
|
||||||
|
writer := &testUDPAddrPortWriter{}
|
||||||
|
w := newConnWriter(writer, rRoute.IP)
|
||||||
|
in := testPacket("hello world!")
|
||||||
|
|
||||||
|
w.RelayControlPacket(in, route, relay)
|
||||||
|
|
||||||
|
out := writer.Written()
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatal(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
if out[0].Addr != relay.RemoteAddr {
|
||||||
|
t.Fatal(out[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024))
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
dec2, ok := rRoute.ControlCipher.Decrypt(dec, make([]byte, 1024))
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(dec2) != string(in) {
|
||||||
|
t.Fatal(dec2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that we can send a data packet directly to a remote route.
|
||||||
|
func TestConnWriter_SendDataPacket_direct(t *testing.T) {
|
||||||
|
route, rRoute, _, _ := testConnWriter_getTestRoutes()
|
||||||
|
route.Direct = true
|
||||||
|
|
||||||
|
writer := &testUDPAddrPortWriter{}
|
||||||
|
w := newConnWriter(writer, rRoute.IP)
|
||||||
|
|
||||||
|
in := []byte("hello world!")
|
||||||
|
w.SendDataPacket(in, route)
|
||||||
|
|
||||||
|
out := writer.Written()
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatal(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
if out[0].Addr != route.RemoteAddr {
|
||||||
|
t.Fatal(out[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
dec, ok := rRoute.DataCipher.Decrypt(out[0].Data, make([]byte, 1024))
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(dec, in) {
|
||||||
|
t.Fatal(dec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that we can relay a data packet via a relay.
|
||||||
|
func TestConnWriter_RelayDataPacket_relay(t *testing.T) {
|
||||||
|
route, rRoute, relay, rRelay := testConnWriter_getTestRoutes()
|
||||||
|
|
||||||
|
writer := &testUDPAddrPortWriter{}
|
||||||
|
w := newConnWriter(writer, rRoute.IP)
|
||||||
|
in := []byte("Hello world!")
|
||||||
|
|
||||||
|
w.RelayDataPacket(in, route, relay)
|
||||||
|
|
||||||
|
out := writer.Written()
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatal(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
if out[0].Addr != relay.RemoteAddr {
|
||||||
|
t.Fatal(out[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
dec, ok := rRelay.DataCipher.Decrypt(out[0].Data, make([]byte, 1024))
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
dec2, ok := rRoute.DataCipher.Decrypt(dec, make([]byte, 1024))
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(dec2, in) {
|
||||||
|
t.Fatal(dec2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that we can send an already encrypted packet.
|
||||||
|
func TestConnWriter_SendEncryptedDataPacket(t *testing.T) {
|
||||||
|
route, rRoute, _, _ := testConnWriter_getTestRoutes()
|
||||||
|
|
||||||
|
writer := &testUDPAddrPortWriter{}
|
||||||
|
w := newConnWriter(writer, rRoute.IP)
|
||||||
|
in := []byte("Hello world!")
|
||||||
|
|
||||||
|
w.SendEncryptedDataPacket(in, route)
|
||||||
|
|
||||||
|
out := writer.Written()
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatal(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
if out[0].Addr != route.RemoteAddr {
|
||||||
|
t.Fatal(out[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(out[0].Data, in) {
|
||||||
|
t.Fatal(out[0])
|
||||||
|
}
|
||||||
|
}
|
58
peer/controlmessage.go
Normal file
58
peer/controlmessage.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"vppn/m"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type controlMsg[T any] struct {
|
||||||
|
SrcIP byte
|
||||||
|
SrcAddr netip.AddrPort
|
||||||
|
// TODO: RecvdAt int64 // Unixmilli.
|
||||||
|
Packet T
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseControlMsg(srcIP byte, srcAddr netip.AddrPort, buf []byte) (any, error) {
|
||||||
|
switch buf[0] {
|
||||||
|
|
||||||
|
case packetTypeSyn:
|
||||||
|
packet, err := parseSynPacket(buf)
|
||||||
|
return controlMsg[synPacket]{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
SrcAddr: srcAddr,
|
||||||
|
Packet: packet,
|
||||||
|
}, err
|
||||||
|
|
||||||
|
case packetTypeAck:
|
||||||
|
packet, err := parseAckPacket(buf)
|
||||||
|
return controlMsg[ackPacket]{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
SrcAddr: srcAddr,
|
||||||
|
Packet: packet,
|
||||||
|
}, err
|
||||||
|
|
||||||
|
case packetTypeProbe:
|
||||||
|
packet, err := parseProbePacket(buf)
|
||||||
|
return controlMsg[probePacket]{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
SrcAddr: srcAddr,
|
||||||
|
Packet: packet,
|
||||||
|
}, err
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, errUnknownPacketType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type peerUpdateMsg struct {
|
||||||
|
PeerIP byte
|
||||||
|
Peer *m.Peer
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type pingTimerMsg struct{}
|
113
peer/crypto.go
Normal file
113
peer/crypto.go
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"log"
|
||||||
|
"net/netip"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/nacl/box"
|
||||||
|
"golang.org/x/crypto/nacl/sign"
|
||||||
|
)
|
||||||
|
|
||||||
|
type cryptoKeys struct {
|
||||||
|
PubKey []byte
|
||||||
|
PrivKey []byte
|
||||||
|
PubSignKey []byte
|
||||||
|
PrivSignKey []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateKeys() cryptoKeys {
|
||||||
|
pubKey, privKey, err := box.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to generate encryption keys: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pubSignKey, privSignKey, err := sign.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to generate signing keys: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cryptoKeys{pubKey[:], privKey[:], pubSignKey[:], privSignKey[:]}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Route must have a ControlCipher.
|
||||||
|
func encryptControlPacket(
|
||||||
|
localIP byte,
|
||||||
|
route *peerRoute,
|
||||||
|
pkt marshaller,
|
||||||
|
tmp []byte,
|
||||||
|
out []byte,
|
||||||
|
) []byte {
|
||||||
|
h := header{
|
||||||
|
StreamID: controlStreamID,
|
||||||
|
Counter: atomic.AddUint64(route.Counter, 1),
|
||||||
|
SourceIP: localIP,
|
||||||
|
DestIP: route.IP,
|
||||||
|
}
|
||||||
|
tmp = pkt.Marshal(tmp)
|
||||||
|
return route.ControlCipher.Encrypt(h, tmp, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a controlMsg[PacketType]. Route must have ControlCipher.
|
||||||
|
func decryptControlPacket(
|
||||||
|
route *peerRoute,
|
||||||
|
fromAddr netip.AddrPort,
|
||||||
|
h header,
|
||||||
|
encrypted []byte,
|
||||||
|
tmp []byte,
|
||||||
|
) (any, error) {
|
||||||
|
out, ok := route.ControlCipher.Decrypt(encrypted, tmp)
|
||||||
|
if !ok {
|
||||||
|
return nil, errDecryptionFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
if route.DupCheck.IsDup(h.Counter) {
|
||||||
|
return nil, errDuplicateSeqNum
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := parseControlMsg(h.SourceIP, fromAddr, out)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func encryptDataPacket(
|
||||||
|
localIP byte,
|
||||||
|
destIP byte,
|
||||||
|
route *peerRoute,
|
||||||
|
data []byte,
|
||||||
|
out []byte,
|
||||||
|
) []byte {
|
||||||
|
h := header{
|
||||||
|
StreamID: dataStreamID,
|
||||||
|
Counter: atomic.AddUint64(route.Counter, 1),
|
||||||
|
SourceIP: localIP,
|
||||||
|
DestIP: destIP,
|
||||||
|
}
|
||||||
|
return route.DataCipher.Encrypt(h, data, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decryptDataPacket(
|
||||||
|
route *peerRoute,
|
||||||
|
h header,
|
||||||
|
encrypted []byte,
|
||||||
|
out []byte,
|
||||||
|
) ([]byte, error) {
|
||||||
|
dec, ok := route.DataCipher.Decrypt(encrypted, out)
|
||||||
|
if !ok {
|
||||||
|
return nil, errDecryptionFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
if route.DupCheck.IsDup(h.Counter) {
|
||||||
|
return nil, errDuplicateSeqNum
|
||||||
|
}
|
||||||
|
|
||||||
|
return dec, nil
|
||||||
|
}
|
213
peer/crypto_test.go
Normal file
213
peer/crypto_test.go
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
|
"net/netip"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newRoutePairForTesting() (*peerRoute, *peerRoute) {
|
||||||
|
keys1 := generateKeys()
|
||||||
|
keys2 := generateKeys()
|
||||||
|
|
||||||
|
r1 := newPeerRoute(1)
|
||||||
|
r1.PubSignKey = keys1.PubSignKey
|
||||||
|
r1.ControlCipher = newControlCipher(keys1.PrivKey, keys2.PubKey)
|
||||||
|
r1.DataCipher = newDataCipher()
|
||||||
|
|
||||||
|
r2 := newPeerRoute(2)
|
||||||
|
r2.PubSignKey = keys2.PubSignKey
|
||||||
|
r2.ControlCipher = newControlCipher(keys2.PrivKey, keys1.PubKey)
|
||||||
|
r2.DataCipher = r1.DataCipher
|
||||||
|
|
||||||
|
return r1, r2
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptControlPacket(t *testing.T) {
|
||||||
|
var (
|
||||||
|
r1, r2 = newRoutePairForTesting()
|
||||||
|
tmp = make([]byte, bufferSize)
|
||||||
|
out = make([]byte, bufferSize)
|
||||||
|
)
|
||||||
|
|
||||||
|
in := synPacket{
|
||||||
|
TraceID: newTraceID(),
|
||||||
|
SharedKey: r1.DataCipher.Key(),
|
||||||
|
Direct: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
enc := encryptControlPacket(r1.IP, r2, in, tmp, out)
|
||||||
|
h, ok := parseHeader(enc)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(h, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
iMsg, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, ok := iMsg.(controlMsg[synPacket])
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(msg.Packet, in) {
|
||||||
|
t.Fatal(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptControlPacket_decryptionFailed(t *testing.T) {
|
||||||
|
var (
|
||||||
|
r1, r2 = newRoutePairForTesting()
|
||||||
|
tmp = make([]byte, bufferSize)
|
||||||
|
out = make([]byte, bufferSize)
|
||||||
|
)
|
||||||
|
|
||||||
|
in := synPacket{
|
||||||
|
TraceID: newTraceID(),
|
||||||
|
SharedKey: r1.DataCipher.Key(),
|
||||||
|
Direct: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
enc := encryptControlPacket(r1.IP, r2, in, tmp, out)
|
||||||
|
h, ok := parseHeader(enc)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(h, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range enc {
|
||||||
|
x := bytes.Clone(enc)
|
||||||
|
x[i]++
|
||||||
|
_, err := decryptControlPacket(r2, netip.AddrPort{}, h, x, tmp)
|
||||||
|
if !errors.Is(err, errDecryptionFailed) {
|
||||||
|
t.Fatal(i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptControlPacket_duplicate(t *testing.T) {
|
||||||
|
var (
|
||||||
|
r1, r2 = newRoutePairForTesting()
|
||||||
|
tmp = make([]byte, bufferSize)
|
||||||
|
out = make([]byte, bufferSize)
|
||||||
|
)
|
||||||
|
|
||||||
|
in := synPacket{
|
||||||
|
TraceID: newTraceID(),
|
||||||
|
SharedKey: r1.DataCipher.Key(),
|
||||||
|
Direct: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
enc := encryptControlPacket(r1.IP, r2, in, tmp, out)
|
||||||
|
h, ok := parseHeader(enc)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(h, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp)
|
||||||
|
if !errors.Is(err, errDuplicateSeqNum) {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptControlPacket_invalidPacket(t *testing.T) {
|
||||||
|
var (
|
||||||
|
r1, r2 = newRoutePairForTesting()
|
||||||
|
tmp = make([]byte, bufferSize)
|
||||||
|
out = make([]byte, bufferSize)
|
||||||
|
)
|
||||||
|
|
||||||
|
in := testPacket("hello!")
|
||||||
|
|
||||||
|
enc := encryptControlPacket(r1.IP, r2, in, tmp, out)
|
||||||
|
h, ok := parseHeader(enc)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(h, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := decryptControlPacket(r2, netip.AddrPort{}, h, enc, tmp)
|
||||||
|
if !errors.Is(err, errUnknownPacketType) {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptDataPacket(t *testing.T) {
|
||||||
|
var (
|
||||||
|
r1, r2 = newRoutePairForTesting()
|
||||||
|
out = make([]byte, bufferSize)
|
||||||
|
data = make([]byte, 1024)
|
||||||
|
)
|
||||||
|
|
||||||
|
rand.Read(data)
|
||||||
|
|
||||||
|
enc := encryptDataPacket(r1.IP, r2.IP, r2, data, out)
|
||||||
|
h, ok := parseHeader(enc)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(h, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := decryptDataPacket(r1, h, bytes.Clone(enc), out)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(data, out) {
|
||||||
|
t.Fatal(data, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptDataPacket_incorrectCipher(t *testing.T) {
|
||||||
|
var (
|
||||||
|
r1, r2 = newRoutePairForTesting()
|
||||||
|
out = make([]byte, bufferSize)
|
||||||
|
data = make([]byte, 1024)
|
||||||
|
)
|
||||||
|
|
||||||
|
rand.Read(data)
|
||||||
|
|
||||||
|
enc := encryptDataPacket(r1.IP, r2.IP, r2, data, bytes.Clone(out))
|
||||||
|
h, ok := parseHeader(enc)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(h, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
r1.DataCipher = newDataCipher()
|
||||||
|
_, err := decryptDataPacket(r1, h, enc, bytes.Clone(out))
|
||||||
|
if !errors.Is(err, errDecryptionFailed) {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptDataPacket_duplicate(t *testing.T) {
|
||||||
|
var (
|
||||||
|
r1, r2 = newRoutePairForTesting()
|
||||||
|
out = make([]byte, bufferSize)
|
||||||
|
data = make([]byte, 1024)
|
||||||
|
)
|
||||||
|
|
||||||
|
rand.Read(data)
|
||||||
|
|
||||||
|
enc := encryptDataPacket(r1.IP, r2.IP, r2, data, bytes.Clone(out))
|
||||||
|
h, ok := parseHeader(enc)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(h, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := decryptDataPacket(r1, h, enc, bytes.Clone(out))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = decryptDataPacket(r1, h, enc, bytes.Clone(out))
|
||||||
|
if !errors.Is(err, errDuplicateSeqNum) {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
76
peer/dupcheck.go
Normal file
76
peer/dupcheck.go
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
type dupCheck struct {
|
||||||
|
bitSet
|
||||||
|
head int
|
||||||
|
tail int
|
||||||
|
headCounter uint64
|
||||||
|
tailCounter uint64 // Also next expected counter value.
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDupCheck(headCounter uint64) *dupCheck {
|
||||||
|
return &dupCheck{
|
||||||
|
headCounter: headCounter,
|
||||||
|
tailCounter: headCounter + 1,
|
||||||
|
tail: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dc *dupCheck) IsDup(counter uint64) bool {
|
||||||
|
|
||||||
|
// Before head => it's late, say it's a dup.
|
||||||
|
if counter < dc.headCounter {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// It's within the counter bounds.
|
||||||
|
if counter < dc.tailCounter {
|
||||||
|
index := (int(counter-dc.headCounter) + dc.head) % bitSetSize
|
||||||
|
if dc.Get(index) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
dc.Set(index)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// It's more than 1 beyond the tail.
|
||||||
|
delta := counter - dc.tailCounter
|
||||||
|
|
||||||
|
// Full clear.
|
||||||
|
if delta >= bitSetSize-1 {
|
||||||
|
dc.ClearAll()
|
||||||
|
dc.Set(0)
|
||||||
|
|
||||||
|
dc.tail = 1
|
||||||
|
dc.head = 2
|
||||||
|
dc.tailCounter = counter + 1
|
||||||
|
dc.headCounter = dc.tailCounter - bitSetSize + 1
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear if necessary.
|
||||||
|
for i := 0; i < int(delta); i++ {
|
||||||
|
dc.put(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
dc.put(true)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dc *dupCheck) put(set bool) {
|
||||||
|
if set {
|
||||||
|
dc.Set(dc.tail)
|
||||||
|
} else {
|
||||||
|
dc.Clear(dc.tail)
|
||||||
|
}
|
||||||
|
|
||||||
|
dc.tail = (dc.tail + 1) % bitSetSize
|
||||||
|
dc.tailCounter++
|
||||||
|
|
||||||
|
if dc.head == dc.tail {
|
||||||
|
dc.head = (dc.head + 1) % bitSetSize
|
||||||
|
dc.headCounter++
|
||||||
|
}
|
||||||
|
}
|
57
peer/dupcheck_test.go
Normal file
57
peer/dupcheck_test.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDupCheck(t *testing.T) {
|
||||||
|
dc := newDupCheck(0)
|
||||||
|
|
||||||
|
for i := range bitSetSize {
|
||||||
|
if dc.IsDup(uint64(i)) {
|
||||||
|
t.Fatal("!")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestCase struct {
|
||||||
|
Counter uint64
|
||||||
|
Dup bool
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []TestCase{
|
||||||
|
{511, true},
|
||||||
|
{0, true},
|
||||||
|
{1, true},
|
||||||
|
{2, true},
|
||||||
|
{3, true},
|
||||||
|
{63, true},
|
||||||
|
{256, true},
|
||||||
|
{510, true},
|
||||||
|
{511, true},
|
||||||
|
{512, false},
|
||||||
|
{0, true},
|
||||||
|
{512, true},
|
||||||
|
{513, false},
|
||||||
|
{517, false},
|
||||||
|
{512, true},
|
||||||
|
{513, true},
|
||||||
|
{514, false},
|
||||||
|
{515, false},
|
||||||
|
{516, false},
|
||||||
|
{517, true},
|
||||||
|
{2512, false},
|
||||||
|
{2512, true},
|
||||||
|
{2001, true},
|
||||||
|
{2002, false},
|
||||||
|
{2002, true},
|
||||||
|
{4000, false},
|
||||||
|
{4000 - 511, true}, // Too old.
|
||||||
|
{4000 - 510, false}, // Just in the window.
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tc := range testCases {
|
||||||
|
if ok := dc.IsDup(tc.Counter); ok != tc.Dup {
|
||||||
|
t.Fatal(i, ok, tc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
10
peer/errors.go
Normal file
10
peer/errors.go
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
errDecryptionFailed = errors.New("decryption failed")
|
||||||
|
errDuplicateSeqNum = errors.New("duplicate sequence number")
|
||||||
|
errMalformedPacket = errors.New("malformed packet")
|
||||||
|
errUnknownPacketType = errors.New("unknown packet type")
|
||||||
|
)
|
19
peer/globals.go
Normal file
19
peer/globals.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
bufferSize = 1536
|
||||||
|
if_mtu = 1200
|
||||||
|
if_queue_len = 2048
|
||||||
|
controlCipherOverhead = 16
|
||||||
|
dataCipherOverhead = 16
|
||||||
|
signOverhead = 64
|
||||||
|
)
|
||||||
|
|
||||||
|
var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(
|
||||||
|
netip.AddrFrom4([4]byte{224, 0, 0, 157}),
|
||||||
|
4560))
|
49
peer/header.go
Normal file
49
peer/header.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import "unsafe"
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
const (
|
||||||
|
headerSize = 12
|
||||||
|
controlStreamID = 2
|
||||||
|
controlHeaderSize = 24
|
||||||
|
dataStreamID = 1
|
||||||
|
dataHeaderSize = 12
|
||||||
|
)
|
||||||
|
|
||||||
|
type header struct {
|
||||||
|
Version byte
|
||||||
|
StreamID byte
|
||||||
|
SourceIP byte
|
||||||
|
DestIP byte
|
||||||
|
Counter uint64 // Init with time.Now().Unix << 30 to ensure monotonic.
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseHeader(b []byte) (h header, ok bool) {
|
||||||
|
if len(b) < headerSize {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.Version = b[0]
|
||||||
|
h.StreamID = b[1]
|
||||||
|
h.SourceIP = b[2]
|
||||||
|
h.DestIP = b[3]
|
||||||
|
h.Counter = *(*uint64)(unsafe.Pointer(&b[4]))
|
||||||
|
return h, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *header) Parse(b []byte) {
|
||||||
|
h.Version = b[0]
|
||||||
|
h.StreamID = b[1]
|
||||||
|
h.SourceIP = b[2]
|
||||||
|
h.DestIP = b[3]
|
||||||
|
h.Counter = *(*uint64)(unsafe.Pointer(&b[4]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *header) Marshal(buf []byte) {
|
||||||
|
buf[0] = h.Version
|
||||||
|
buf[1] = h.StreamID
|
||||||
|
buf[2] = h.SourceIP
|
||||||
|
buf[3] = h.DestIP
|
||||||
|
*(*uint64)(unsafe.Pointer(&buf[4])) = h.Counter
|
||||||
|
}
|
21
peer/header_test.go
Normal file
21
peer/header_test.go
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestHeaderMarshalParse(t *testing.T) {
|
||||||
|
nIn := header{
|
||||||
|
StreamID: 23,
|
||||||
|
Counter: 3212,
|
||||||
|
SourceIP: 34,
|
||||||
|
DestIP: 200,
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, headerSize)
|
||||||
|
nIn.Marshal(buf)
|
||||||
|
|
||||||
|
nOut := header{}
|
||||||
|
nOut.Parse(buf)
|
||||||
|
if nIn != nOut {
|
||||||
|
t.Fatal(nIn, nOut)
|
||||||
|
}
|
||||||
|
}
|
100
peer/ifreader.go
Normal file
100
peer/ifreader.go
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ifReader struct {
|
||||||
|
iface io.Reader
|
||||||
|
routes [256]*atomic.Pointer[peerRoute]
|
||||||
|
relay *atomic.Pointer[peerRoute]
|
||||||
|
sender dataPacketSender
|
||||||
|
}
|
||||||
|
|
||||||
|
func newIFReader(
|
||||||
|
iface io.Reader,
|
||||||
|
routes [256]*atomic.Pointer[peerRoute],
|
||||||
|
relay *atomic.Pointer[peerRoute],
|
||||||
|
sender dataPacketSender,
|
||||||
|
) *ifReader {
|
||||||
|
return &ifReader{
|
||||||
|
iface: iface,
|
||||||
|
routes: routes,
|
||||||
|
relay: relay,
|
||||||
|
sender: sender,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ifReader) Run() {
|
||||||
|
var (
|
||||||
|
packet = make([]byte, bufferSize)
|
||||||
|
remoteIP byte
|
||||||
|
ok bool
|
||||||
|
)
|
||||||
|
|
||||||
|
for {
|
||||||
|
packet = r.readNextPacket(packet)
|
||||||
|
if remoteIP, ok = r.parsePacket(packet); ok {
|
||||||
|
r.sendPacket(packet, remoteIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ifReader) sendPacket(pkt []byte, remoteIP byte) {
|
||||||
|
route := r.routes[remoteIP].Load()
|
||||||
|
if !route.Up {
|
||||||
|
log.Printf("Route not connected: %d", remoteIP)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Direct path => early return.
|
||||||
|
if route.Direct {
|
||||||
|
r.sender.SendDataPacket(pkt, route)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if relay := r.relay.Load(); relay != nil && relay.Up {
|
||||||
|
r.sender.RelayDataPacket(pkt, route, relay)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get next packet, returning packet, and destination ip.
|
||||||
|
func (r *ifReader) readNextPacket(buf []byte) []byte {
|
||||||
|
n, err := r.iface.Read(buf[:cap(buf)])
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to read from interface: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf[:n]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ifReader) parsePacket(buf []byte) (byte, bool) {
|
||||||
|
n := len(buf)
|
||||||
|
if n == 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
version := buf[0] >> 4
|
||||||
|
|
||||||
|
switch version {
|
||||||
|
case 4:
|
||||||
|
if n < 20 {
|
||||||
|
log.Printf("Short IPv4 packet: %d", len(buf))
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return buf[19], true
|
||||||
|
|
||||||
|
case 6:
|
||||||
|
if len(buf) < 40 {
|
||||||
|
log.Printf("Short IPv6 packet: %d", len(buf))
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return buf[39], true
|
||||||
|
|
||||||
|
default:
|
||||||
|
log.Printf("Invalid IP packet version: %v", version)
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
232
peer/ifreader_test.go
Normal file
232
peer/ifreader_test.go
Normal file
@ -0,0 +1,232 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test that we parse IPv4 packets correctly.
|
||||||
|
func TestIFReader_parsePacket_ipv4(t *testing.T) {
|
||||||
|
r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil)
|
||||||
|
|
||||||
|
pkt := make([]byte, 1234)
|
||||||
|
pkt[0] = 4 << 4
|
||||||
|
pkt[19] = 128
|
||||||
|
|
||||||
|
if ip, ok := r.parsePacket(pkt); !ok || ip != 128 {
|
||||||
|
t.Fatal(ip, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that we parse IPv6 packets correctly.
|
||||||
|
func TestIFReader_parsePacket_ipv6(t *testing.T) {
|
||||||
|
r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil)
|
||||||
|
|
||||||
|
pkt := make([]byte, 1234)
|
||||||
|
pkt[0] = 6 << 4
|
||||||
|
pkt[39] = 42
|
||||||
|
|
||||||
|
if ip, ok := r.parsePacket(pkt); !ok || ip != 42 {
|
||||||
|
t.Fatal(ip, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that empty packets work as expected.
|
||||||
|
func TestIFReader_parsePacket_emptyPacket(t *testing.T) {
|
||||||
|
r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil)
|
||||||
|
|
||||||
|
pkt := make([]byte, 0)
|
||||||
|
if ip, ok := r.parsePacket(pkt); ok {
|
||||||
|
t.Fatal(ip, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that invalid IP versions fail.
|
||||||
|
func TestIFReader_parsePacket_invalidIPVersion(t *testing.T) {
|
||||||
|
r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil)
|
||||||
|
|
||||||
|
for i := byte(1); i < 16; i++ {
|
||||||
|
if i == 4 || i == 6 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pkt := make([]byte, 1234)
|
||||||
|
pkt[0] = i << 4
|
||||||
|
|
||||||
|
if ip, ok := r.parsePacket(pkt); ok {
|
||||||
|
t.Fatal(i, ip, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that short IPv4 packets fail.
|
||||||
|
func TestIFReader_parsePacket_shortIPv4(t *testing.T) {
|
||||||
|
r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil)
|
||||||
|
|
||||||
|
pkt := make([]byte, 19)
|
||||||
|
pkt[0] = 4 << 4
|
||||||
|
|
||||||
|
if ip, ok := r.parsePacket(pkt); ok {
|
||||||
|
t.Fatal(ip, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that short IPv6 packets fail.
|
||||||
|
func TestIFReader_parsePacket_shortIPv6(t *testing.T) {
|
||||||
|
r := newIFReader(nil, [256]*atomic.Pointer[peerRoute]{}, nil, nil)
|
||||||
|
|
||||||
|
pkt := make([]byte, 39)
|
||||||
|
pkt[0] = 6 << 4
|
||||||
|
|
||||||
|
if ip, ok := r.parsePacket(pkt); ok {
|
||||||
|
t.Fatal(ip, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that we can read a packet.
|
||||||
|
func TestIFReader_readNextpacket(t *testing.T) {
|
||||||
|
in, out := net.Pipe()
|
||||||
|
r := newIFReader(out, [256]*atomic.Pointer[peerRoute]{}, nil, nil)
|
||||||
|
defer in.Close()
|
||||||
|
defer out.Close()
|
||||||
|
|
||||||
|
go in.Write([]byte("hello world!"))
|
||||||
|
|
||||||
|
pkt := r.readNextPacket(make([]byte, bufferSize))
|
||||||
|
if !bytes.Equal(pkt, []byte("hello world!")) {
|
||||||
|
t.Fatalf("%s", pkt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type sentPacket struct {
|
||||||
|
Relayed bool
|
||||||
|
Packet []byte
|
||||||
|
Route peerRoute
|
||||||
|
Relay peerRoute
|
||||||
|
}
|
||||||
|
|
||||||
|
type sendPacketTestHarness struct {
|
||||||
|
Packets []sentPacket
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sendPacketTestHarness) SendDataPacket(pkt []byte, route *peerRoute) {
|
||||||
|
h.Packets = append(h.Packets, sentPacket{
|
||||||
|
Packet: bytes.Clone(pkt),
|
||||||
|
Route: *route,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *sendPacketTestHarness) RelayDataPacket(pkt []byte, route, relay *peerRoute) {
|
||||||
|
h.Packets = append(h.Packets, sentPacket{
|
||||||
|
Relayed: true,
|
||||||
|
Packet: bytes.Clone(pkt),
|
||||||
|
Route: *route,
|
||||||
|
Relay: *relay,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func newIFReaderForSendPacketTesting() (*ifReader, *sendPacketTestHarness) {
|
||||||
|
h := &sendPacketTestHarness{}
|
||||||
|
|
||||||
|
routes := [256]*atomic.Pointer[peerRoute]{}
|
||||||
|
for i := range routes {
|
||||||
|
routes[i] = &atomic.Pointer[peerRoute]{}
|
||||||
|
routes[i].Store(&peerRoute{})
|
||||||
|
}
|
||||||
|
relay := &atomic.Pointer[peerRoute]{}
|
||||||
|
r := newIFReader(nil, routes, relay, h)
|
||||||
|
return r, h
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that we can send a packet directly.
|
||||||
|
func TestIFReader_sendPacket_direct(t *testing.T) {
|
||||||
|
r, h := newIFReaderForSendPacketTesting()
|
||||||
|
|
||||||
|
route := r.routes[2].Load()
|
||||||
|
route.Up = true
|
||||||
|
route.Direct = true
|
||||||
|
|
||||||
|
in := []byte("hello world")
|
||||||
|
|
||||||
|
r.sendPacket(in, 2)
|
||||||
|
if len(h.Packets) != 1 {
|
||||||
|
t.Fatal(h.Packets)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := sentPacket{
|
||||||
|
Relayed: false,
|
||||||
|
Packet: in,
|
||||||
|
Route: *route,
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(h.Packets[0], expected) {
|
||||||
|
t.Fatal(h.Packets[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that we don't send a packet if route isn't up.
|
||||||
|
func TestIFReader_sendPacket_directNotUp(t *testing.T) {
|
||||||
|
r, h := newIFReaderForSendPacketTesting()
|
||||||
|
|
||||||
|
route := r.routes[2].Load()
|
||||||
|
route.Direct = true
|
||||||
|
|
||||||
|
in := []byte("hello world")
|
||||||
|
|
||||||
|
r.sendPacket(in, 2)
|
||||||
|
if len(h.Packets) != 0 {
|
||||||
|
t.Fatal(h.Packets)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that we can send a packet via a relay.
|
||||||
|
func TestIFReader_sendPacket_relayed(t *testing.T) {
|
||||||
|
r, h := newIFReaderForSendPacketTesting()
|
||||||
|
|
||||||
|
route := r.routes[2].Load()
|
||||||
|
route.Up = true
|
||||||
|
route.Direct = false
|
||||||
|
|
||||||
|
relay := r.routes[3].Load()
|
||||||
|
r.relay.Store(relay)
|
||||||
|
relay.Up = true
|
||||||
|
relay.Direct = true
|
||||||
|
|
||||||
|
in := []byte("hello world")
|
||||||
|
|
||||||
|
r.sendPacket(in, 2)
|
||||||
|
if len(h.Packets) != 1 {
|
||||||
|
t.Fatal(h.Packets)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := sentPacket{
|
||||||
|
Relayed: true,
|
||||||
|
Packet: in,
|
||||||
|
Route: *route,
|
||||||
|
Relay: *relay,
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(h.Packets[0], expected) {
|
||||||
|
t.Fatal(h.Packets[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that we don't try to send on a nil relay IP.
|
||||||
|
func TestIFReader_sendPacket_nilRealy(t *testing.T) {
|
||||||
|
r, h := newIFReaderForSendPacketTesting()
|
||||||
|
|
||||||
|
route := r.routes[2].Load()
|
||||||
|
route.Up = true
|
||||||
|
route.Direct = false
|
||||||
|
|
||||||
|
in := []byte("hello world")
|
||||||
|
|
||||||
|
r.sendPacket(in, 2)
|
||||||
|
if len(h.Packets) != 0 {
|
||||||
|
t.Fatal(h.Packets)
|
||||||
|
}
|
||||||
|
}
|
5
peer/ifwriter.go
Normal file
5
peer/ifwriter.go
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import "io"
|
||||||
|
|
||||||
|
type ifWriter io.Writer
|
28
peer/interfaces.go
Normal file
28
peer/interfaces.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import "net/netip"
|
||||||
|
|
||||||
|
type udpReader interface {
|
||||||
|
ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpWriter interface {
|
||||||
|
WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type marshaller interface {
|
||||||
|
Marshal([]byte) []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type dataPacketSender interface {
|
||||||
|
SendDataPacket(pkt []byte, route *peerRoute)
|
||||||
|
RelayDataPacket(pkt []byte, route, relay *peerRoute)
|
||||||
|
}
|
||||||
|
|
||||||
|
type encryptedPacketSender interface {
|
||||||
|
SendEncryptedDataPacket(pkt []byte, route *peerRoute)
|
||||||
|
}
|
||||||
|
|
||||||
|
type controlMsgHandler interface {
|
||||||
|
HandleControlMsg(pkt any)
|
||||||
|
}
|
62
peer/mcwriter.go
Normal file
62
peer/mcwriter.go
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/nacl/sign"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type mcUDPWriter interface {
|
||||||
|
WriteToUDP([]byte, *net.UDPAddr) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func createLocalDiscoveryPacket(localIP byte, signingKey []byte) []byte {
|
||||||
|
h := header{
|
||||||
|
SourceIP: localIP,
|
||||||
|
DestIP: 255,
|
||||||
|
}
|
||||||
|
buf := make([]byte, headerSize)
|
||||||
|
h.Marshal(buf)
|
||||||
|
out := make([]byte, headerSize+signOverhead)
|
||||||
|
return sign.Sign(out[:0], buf, (*[64]byte)(signingKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
func headerFromLocalDiscoveryPacket(pkt []byte) (h header, ok bool) {
|
||||||
|
if len(pkt) != headerSize+signOverhead {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.Parse(pkt[signOverhead:])
|
||||||
|
ok = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyLocalDiscoveryPacket(pkt, buf []byte, pubSignKey []byte) bool {
|
||||||
|
_, ok := sign.Open(buf[:0], pkt, (*[32]byte)(pubSignKey))
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type mcWriter struct {
|
||||||
|
conn mcUDPWriter
|
||||||
|
discoveryPacket []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMCWriter(conn mcUDPWriter, localIP byte, signingKey []byte) *mcWriter {
|
||||||
|
return &mcWriter{
|
||||||
|
conn: conn,
|
||||||
|
discoveryPacket: createLocalDiscoveryPacket(localIP, signingKey),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *mcWriter) SendLocalDiscovery() {
|
||||||
|
if _, err := w.conn.WriteToUDP(w.discoveryPacket, multicastAddr); err != nil {
|
||||||
|
log.Printf("[MCWriter] Failed to write multicast UDP packet: %v", err)
|
||||||
|
}
|
||||||
|
}
|
102
peer/mcwriter_test.go
Normal file
102
peer/mcwriter_test.go
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Testing that we can create and verify a local discovery packet.
|
||||||
|
func TestVerifyLocalDiscoveryPacket_valid(t *testing.T) {
|
||||||
|
keys := generateKeys()
|
||||||
|
|
||||||
|
created := createLocalDiscoveryPacket(55, keys.PrivSignKey)
|
||||||
|
|
||||||
|
header, ok := headerFromLocalDiscoveryPacket(created)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
if header.SourceIP != 55 || header.DestIP != 255 {
|
||||||
|
t.Fatal(header)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !verifyLocalDiscoveryPacket(created, make([]byte, 1024), keys.PubSignKey) {
|
||||||
|
t.Fatal("Not valid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that we don't try to parse short packets.
|
||||||
|
func TestVerifyLocalDiscoveryPacket_tooShort(t *testing.T) {
|
||||||
|
keys := generateKeys()
|
||||||
|
|
||||||
|
created := createLocalDiscoveryPacket(55, keys.PrivSignKey)
|
||||||
|
|
||||||
|
_, ok := headerFromLocalDiscoveryPacket(created[:len(created)-1])
|
||||||
|
if ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Testing that modifying a packet makes it invalid.
|
||||||
|
func TestVerifyLocalDiscoveryPacket_invalid(t *testing.T) {
|
||||||
|
keys := generateKeys()
|
||||||
|
|
||||||
|
created := createLocalDiscoveryPacket(55, keys.PrivSignKey)
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
for i := range created {
|
||||||
|
modified := bytes.Clone(created)
|
||||||
|
modified[i]++
|
||||||
|
if verifyLocalDiscoveryPacket(modified, buf, keys.PubSignKey) {
|
||||||
|
t.Fatal("Verification should have failed.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type testUDPWriter struct {
|
||||||
|
written [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testUDPWriter) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
|
||||||
|
w.written = append(w.written, bytes.Clone(b))
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *testUDPWriter) Written() [][]byte {
|
||||||
|
out := w.written
|
||||||
|
w.written = [][]byte{}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Testing that the mcWriter sends local discovery packets as expected.
|
||||||
|
func TestMCWriter_SendLocalDiscovery(t *testing.T) {
|
||||||
|
keys := generateKeys()
|
||||||
|
writer := &testUDPWriter{}
|
||||||
|
|
||||||
|
mcw := newMCWriter(writer, 42, keys.PrivSignKey)
|
||||||
|
mcw.SendLocalDiscovery()
|
||||||
|
|
||||||
|
out := writer.Written()
|
||||||
|
if len(out) != 1 {
|
||||||
|
t.Fatal(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
pkt := out[0]
|
||||||
|
|
||||||
|
header, ok := headerFromLocalDiscoveryPacket(pkt)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal(ok)
|
||||||
|
}
|
||||||
|
if header.SourceIP != 42 || header.DestIP != 255 {
|
||||||
|
t.Fatal(header)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !verifyLocalDiscoveryPacket(pkt, make([]byte, 1024), keys.PubSignKey) {
|
||||||
|
t.Fatal("Verification should succeed.")
|
||||||
|
}
|
||||||
|
}
|
190
peer/packets-util.go
Normal file
190
peer/packets-util.go
Normal file
@ -0,0 +1,190 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
var traceIDCounter uint64 = uint64(time.Now().Unix()<<30) + 1
|
||||||
|
|
||||||
|
func newTraceID() uint64 {
|
||||||
|
return atomic.AddUint64(&traceIDCounter, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type binWriter struct {
|
||||||
|
b []byte
|
||||||
|
i int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBinWriter(buf []byte) *binWriter {
|
||||||
|
buf = buf[:cap(buf)]
|
||||||
|
return &binWriter{buf, 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *binWriter) Bool(b bool) *binWriter {
|
||||||
|
if b {
|
||||||
|
return w.Byte(1)
|
||||||
|
}
|
||||||
|
return w.Byte(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *binWriter) Byte(b byte) *binWriter {
|
||||||
|
w.b[w.i] = b
|
||||||
|
w.i++
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *binWriter) SharedKey(key [32]byte) *binWriter {
|
||||||
|
copy(w.b[w.i:w.i+32], key[:])
|
||||||
|
w.i += 32
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *binWriter) Uint16(x uint16) *binWriter {
|
||||||
|
*(*uint16)(unsafe.Pointer(&w.b[w.i])) = x
|
||||||
|
w.i += 2
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *binWriter) Uint64(x uint64) *binWriter {
|
||||||
|
*(*uint64)(unsafe.Pointer(&w.b[w.i])) = x
|
||||||
|
w.i += 8
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *binWriter) Int64(x int64) *binWriter {
|
||||||
|
*(*int64)(unsafe.Pointer(&w.b[w.i])) = x
|
||||||
|
w.i += 8
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *binWriter) AddrPort(addrPort netip.AddrPort) *binWriter {
|
||||||
|
w.Bool(addrPort.IsValid())
|
||||||
|
addr := addrPort.Addr().As16()
|
||||||
|
copy(w.b[w.i:w.i+16], addr[:])
|
||||||
|
w.i += 16
|
||||||
|
return w.Uint16(addrPort.Port())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *binWriter) AddrPortArray(l [8]netip.AddrPort) *binWriter {
|
||||||
|
for _, addrPort := range l {
|
||||||
|
w.AddrPort(addrPort)
|
||||||
|
}
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *binWriter) Build() []byte {
|
||||||
|
return w.b[:w.i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type binReader struct {
|
||||||
|
b []byte
|
||||||
|
i int
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBinReader(buf []byte) *binReader {
|
||||||
|
return &binReader{b: buf}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) hasBytes(n int) bool {
|
||||||
|
if r.err != nil || (len(r.b)-r.i) < n {
|
||||||
|
r.err = errMalformedPacket
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) Bool(b *bool) *binReader {
|
||||||
|
var bb byte
|
||||||
|
r.Byte(&bb)
|
||||||
|
*b = bb != 0
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) Byte(b *byte) *binReader {
|
||||||
|
if !r.hasBytes(1) {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
*b = r.b[r.i]
|
||||||
|
r.i++
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) SharedKey(x *[32]byte) *binReader {
|
||||||
|
if !r.hasBytes(32) {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
*x = ([32]byte)(r.b[r.i : r.i+32])
|
||||||
|
r.i += 32
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) Uint16(x *uint16) *binReader {
|
||||||
|
if !r.hasBytes(2) {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
*x = *(*uint16)(unsafe.Pointer(&r.b[r.i]))
|
||||||
|
r.i += 2
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) Uint64(x *uint64) *binReader {
|
||||||
|
if !r.hasBytes(8) {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
*x = *(*uint64)(unsafe.Pointer(&r.b[r.i]))
|
||||||
|
r.i += 8
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) Int64(x *int64) *binReader {
|
||||||
|
if !r.hasBytes(8) {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
*x = *(*int64)(unsafe.Pointer(&r.b[r.i]))
|
||||||
|
r.i += 8
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) AddrPort(x *netip.AddrPort) *binReader {
|
||||||
|
if !r.hasBytes(19) {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
valid bool
|
||||||
|
port uint16
|
||||||
|
)
|
||||||
|
|
||||||
|
r.Bool(&valid)
|
||||||
|
addr := netip.AddrFrom16(([16]byte)(r.b[r.i : r.i+16])).Unmap()
|
||||||
|
r.i += 16
|
||||||
|
|
||||||
|
r.Uint16(&port)
|
||||||
|
|
||||||
|
if valid {
|
||||||
|
*x = netip.AddrPortFrom(addr, port)
|
||||||
|
} else {
|
||||||
|
*x = netip.AddrPort{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) AddrPortArray(x *[8]netip.AddrPort) *binReader {
|
||||||
|
for i := range x {
|
||||||
|
r.AddrPort(&x[i])
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *binReader) Error() error {
|
||||||
|
return r.err
|
||||||
|
}
|
56
peer/packets-util_test.go
Normal file
56
peer/packets-util_test.go
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBinWriteRead(t *testing.T) {
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
|
||||||
|
type Item struct {
|
||||||
|
Type byte
|
||||||
|
TraceID uint64
|
||||||
|
Addrs [8]netip.AddrPort
|
||||||
|
DestAddr netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
in := Item{
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
[8]netip.AddrPort{},
|
||||||
|
netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 22),
|
||||||
|
}
|
||||||
|
|
||||||
|
in.Addrs[0] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{0, 1, 2, 3}), 20)
|
||||||
|
in.Addrs[2] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 22)
|
||||||
|
in.Addrs[3] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 3}), 23)
|
||||||
|
in.Addrs[4] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 4}), 24)
|
||||||
|
in.Addrs[5] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 5}), 25)
|
||||||
|
in.Addrs[6] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{2, 3, 4, 6}), 26)
|
||||||
|
in.Addrs[7] = netip.AddrPortFrom(netip.AddrFrom4([4]byte{7, 8, 9, 7}), 27)
|
||||||
|
|
||||||
|
buf = newBinWriter(buf).
|
||||||
|
Byte(in.Type).
|
||||||
|
Uint64(in.TraceID).
|
||||||
|
AddrPort(in.DestAddr).
|
||||||
|
AddrPortArray(in.Addrs).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
out := Item{}
|
||||||
|
|
||||||
|
err := newBinReader(buf).
|
||||||
|
Byte(&out.Type).
|
||||||
|
Uint64(&out.TraceID).
|
||||||
|
AddrPort(&out.DestAddr).
|
||||||
|
AddrPortArray(&out.Addrs).
|
||||||
|
Error()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(in, out) {
|
||||||
|
t.Fatal(in, out)
|
||||||
|
}
|
||||||
|
}
|
123
peer/packets.go
Normal file
123
peer/packets.go
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
packetTypeSyn = iota + 1
|
||||||
|
packetTypeSynAck
|
||||||
|
packetTypeAck
|
||||||
|
packetTypeProbe
|
||||||
|
packetTypeAddrDiscovery
|
||||||
|
)
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type synPacket struct {
|
||||||
|
TraceID uint64 // TraceID to match response w/ request.
|
||||||
|
// TODO: SentAt int64 // Unixmilli.
|
||||||
|
SharedKey [32]byte // Our shared key.
|
||||||
|
Direct bool
|
||||||
|
PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p synPacket) Marshal(buf []byte) []byte {
|
||||||
|
return newBinWriter(buf).
|
||||||
|
Byte(packetTypeSyn).
|
||||||
|
Uint64(p.TraceID).
|
||||||
|
SharedKey(p.SharedKey).
|
||||||
|
Bool(p.Direct).
|
||||||
|
AddrPort(p.PossibleAddrs[0]).
|
||||||
|
AddrPort(p.PossibleAddrs[1]).
|
||||||
|
AddrPort(p.PossibleAddrs[2]).
|
||||||
|
AddrPort(p.PossibleAddrs[3]).
|
||||||
|
AddrPort(p.PossibleAddrs[4]).
|
||||||
|
AddrPort(p.PossibleAddrs[5]).
|
||||||
|
AddrPort(p.PossibleAddrs[6]).
|
||||||
|
AddrPort(p.PossibleAddrs[7]).
|
||||||
|
Build()
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSynPacket(buf []byte) (p synPacket, err error) {
|
||||||
|
err = newBinReader(buf[1:]).
|
||||||
|
Uint64(&p.TraceID).
|
||||||
|
SharedKey(&p.SharedKey).
|
||||||
|
Bool(&p.Direct).
|
||||||
|
AddrPort(&p.PossibleAddrs[0]).
|
||||||
|
AddrPort(&p.PossibleAddrs[1]).
|
||||||
|
AddrPort(&p.PossibleAddrs[2]).
|
||||||
|
AddrPort(&p.PossibleAddrs[3]).
|
||||||
|
AddrPort(&p.PossibleAddrs[4]).
|
||||||
|
AddrPort(&p.PossibleAddrs[5]).
|
||||||
|
AddrPort(&p.PossibleAddrs[6]).
|
||||||
|
AddrPort(&p.PossibleAddrs[7]).
|
||||||
|
Error()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type ackPacket struct {
|
||||||
|
TraceID uint64
|
||||||
|
ToAddr netip.AddrPort
|
||||||
|
PossibleAddrs [8]netip.AddrPort // Possible public addresses of the sender.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p ackPacket) Marshal(buf []byte) []byte {
|
||||||
|
return newBinWriter(buf).
|
||||||
|
Byte(packetTypeAck).
|
||||||
|
Uint64(p.TraceID).
|
||||||
|
AddrPort(p.ToAddr).
|
||||||
|
AddrPort(p.PossibleAddrs[0]).
|
||||||
|
AddrPort(p.PossibleAddrs[1]).
|
||||||
|
AddrPort(p.PossibleAddrs[2]).
|
||||||
|
AddrPort(p.PossibleAddrs[3]).
|
||||||
|
AddrPort(p.PossibleAddrs[4]).
|
||||||
|
AddrPort(p.PossibleAddrs[5]).
|
||||||
|
AddrPort(p.PossibleAddrs[6]).
|
||||||
|
AddrPort(p.PossibleAddrs[7]).
|
||||||
|
Build()
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAckPacket(buf []byte) (p ackPacket, err error) {
|
||||||
|
err = newBinReader(buf[1:]).
|
||||||
|
Uint64(&p.TraceID).
|
||||||
|
AddrPort(&p.ToAddr).
|
||||||
|
AddrPort(&p.PossibleAddrs[0]).
|
||||||
|
AddrPort(&p.PossibleAddrs[1]).
|
||||||
|
AddrPort(&p.PossibleAddrs[2]).
|
||||||
|
AddrPort(&p.PossibleAddrs[3]).
|
||||||
|
AddrPort(&p.PossibleAddrs[4]).
|
||||||
|
AddrPort(&p.PossibleAddrs[5]).
|
||||||
|
AddrPort(&p.PossibleAddrs[6]).
|
||||||
|
AddrPort(&p.PossibleAddrs[7]).
|
||||||
|
Error()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// A probeReqPacket is sent from a client to a server to determine if direct
|
||||||
|
// UDP communication can be used.
|
||||||
|
type probePacket struct {
|
||||||
|
TraceID uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p probePacket) Marshal(buf []byte) []byte {
|
||||||
|
return newBinWriter(buf).
|
||||||
|
Byte(packetTypeProbe).
|
||||||
|
Uint64(p.TraceID).
|
||||||
|
Build()
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseProbePacket(buf []byte) (p probePacket, err error) {
|
||||||
|
err = newBinReader(buf[1:]).
|
||||||
|
Uint64(&p.TraceID).
|
||||||
|
Error()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type localDiscoveryPacket struct{}
|
1
peer/packets_test.go
Normal file
1
peer/packets_test.go
Normal file
@ -0,0 +1 @@
|
|||||||
|
package peer
|
29
peer/state.go
Normal file
29
peer/state.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type peerRoute struct {
|
||||||
|
IP byte // VPN IP of peer (last byte).
|
||||||
|
Up bool // True if data can be sent on the route.
|
||||||
|
Relay bool // True if the peer is a relay.
|
||||||
|
Direct bool // True if this is a direct connection.
|
||||||
|
PubSignKey []byte
|
||||||
|
ControlCipher *controlCipher
|
||||||
|
DataCipher *dataCipher
|
||||||
|
RemoteAddr netip.AddrPort // Remote address if directly connected.
|
||||||
|
|
||||||
|
Counter *uint64 // For sending to. Atomic access only.
|
||||||
|
DupCheck *dupCheck // For receiving from. Not safe for concurrent use.
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPeerRoute(ip byte) *peerRoute {
|
||||||
|
counter := uint64(time.Now().Unix()<<30 + 1)
|
||||||
|
return &peerRoute{
|
||||||
|
IP: ip,
|
||||||
|
Counter: &counter,
|
||||||
|
DupCheck: newDupCheck(0),
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user