sym-encryption #1

Merged
johnnylee merged 18 commits from sym-encryption into main 2024-12-24 18:37:44 +00:00
18 changed files with 664 additions and 115 deletions
Showing only changes of commit 8ab6158469 - Show all commits

View File

@ -2,6 +2,7 @@
## Roadmap ## Roadmap
* Rename Mediator -> Relay
* Node: use symmetric encryption after handshake * Node: use symmetric encryption after handshake
* AEAD-AES uses a 12 byte nonce. We need to shrink the header: * AEAD-AES uses a 12 byte nonce. We need to shrink the header:
* Remove Forward and replace it with a HeaderFlags bitfield. * Remove Forward and replace it with a HeaderFlags bitfield.

8
node/addrutil.go Normal file
View File

@ -0,0 +1,8 @@
package node
import "net/netip"
func addrIsValid(in []byte) bool {
_, ok := netip.AddrFromSlice(in)
return ok
}

26
node/cipher-control.go Normal file
View File

@ -0,0 +1,26 @@
package node
import "golang.org/x/crypto/nacl/box"
type controlCipher struct {
sharedKey [32]byte
}
func newControlCipher(privKey, pubKey []byte) *controlCipher {
shared := [32]byte{}
box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey))
return &controlCipher{shared}
}
func (cc *controlCipher) Encrypt(h xHeader, data, out []byte) []byte {
const s = controlHeaderSize
out = out[:s+controlCipherOverhead+len(data)]
h.Marshal(out[:s])
box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &cc.sharedKey)
return out
}
func (cc *controlCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) {
const s = controlHeaderSize
return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &cc.sharedKey)
}

View File

@ -3,12 +3,13 @@ package node
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"reflect"
"testing" "testing"
"golang.org/x/crypto/nacl/box" "golang.org/x/crypto/nacl/box"
) )
func newRoutingCipherForTesting() (c1, c2 routingCipher) { func newControlCipherForTesting() (c1, c2 *controlCipher) {
pubKey1, privKey1, err := box.GenerateKey(rand.Reader) pubKey1, privKey1, err := box.GenerateKey(rand.Reader)
if err != nil { if err != nil {
panic(err) panic(err)
@ -19,14 +20,14 @@ func newRoutingCipherForTesting() (c1, c2 routingCipher) {
panic(err) panic(err)
} }
return newRoutingCipher(privKey1[:], pubKey2[:]), return newControlCipher(privKey1[:], pubKey2[:]),
newRoutingCipher(privKey2[:], pubKey1[:]) newControlCipher(privKey2[:], pubKey1[:])
} }
func TestRoutingCipher(t *testing.T) { func TestControlCipher(t *testing.T) {
c1, c2 := newRoutingCipherForTesting() c1, c2 := newControlCipherForTesting()
maxSizePlaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead) maxSizePlaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
rand.Read(maxSizePlaintext) rand.Read(maxSizePlaintext)
testCases := [][]byte{ testCases := [][]byte{
@ -40,6 +41,7 @@ func TestRoutingCipher(t *testing.T) {
for _, plaintext := range testCases { for _, plaintext := range testCases {
h1 := xHeader{ h1 := xHeader{
StreamID: controlStreamID,
Counter: 235153, Counter: 235153,
SourceIP: 4, SourceIP: 4,
DestIP: 88, DestIP: 88,
@ -49,6 +51,12 @@ func TestRoutingCipher(t *testing.T) {
encrypted = c1.Encrypt(h1, plaintext, encrypted) encrypted = c1.Encrypt(h1, plaintext, encrypted)
h2 := xHeader{}
h2.Parse(encrypted)
if !reflect.DeepEqual(h1, h2) {
t.Fatal(h1, h2)
}
decrypted, ok := c2.Decrypt(encrypted, make([]byte, bufferSize)) decrypted, ok := c2.Decrypt(encrypted, make([]byte, bufferSize))
if !ok { if !ok {
t.Fatal(ok) t.Fatal(ok)
@ -60,9 +68,9 @@ func TestRoutingCipher(t *testing.T) {
} }
} }
func TestRoutingCipher_ShortCiphertext(t *testing.T) { func TestControlCipher_ShortCiphertext(t *testing.T) {
c1, _ := newRoutingCipherForTesting() c1, _ := newControlCipherForTesting()
shortText := make([]byte, routingHeaderSize+routingCipherOverhead-1) shortText := make([]byte, controlHeaderSize+controlCipherOverhead-1)
rand.Read(shortText) rand.Read(shortText)
_, ok := c1.Decrypt(shortText, make([]byte, bufferSize)) _, ok := c1.Decrypt(shortText, make([]byte, bufferSize))
if ok { if ok {
@ -70,15 +78,15 @@ func TestRoutingCipher_ShortCiphertext(t *testing.T) {
} }
} }
func BenchmarkRoutingCipher_Encrypt(b *testing.B) { func BenchmarkControlCipher_Encrypt(b *testing.B) {
c1, _ := newRoutingCipherForTesting() c1, _ := newControlCipherForTesting()
h1 := xHeader{ h1 := xHeader{
Counter: 235153, Counter: 235153,
SourceIP: 4, SourceIP: 4,
DestIP: 88, DestIP: 88,
} }
plaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead) plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
rand.Read(plaintext) rand.Read(plaintext)
encrypted := make([]byte, bufferSize) encrypted := make([]byte, bufferSize)
@ -89,8 +97,8 @@ func BenchmarkRoutingCipher_Encrypt(b *testing.B) {
} }
} }
func BenchmarkRoutingCipher_Decrypt(b *testing.B) { func BenchmarkControlCipher_Decrypt(b *testing.B) {
c1, c2 := newRoutingCipherForTesting() c1, c2 := newControlCipherForTesting()
h1 := xHeader{ h1 := xHeader{
Counter: 235153, Counter: 235153,
@ -98,7 +106,7 @@ func BenchmarkRoutingCipher_Decrypt(b *testing.B) {
DestIP: 88, DestIP: 88,
} }
plaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead) plaintext := make([]byte, bufferSize-controlHeaderSize-controlCipherOverhead)
rand.Read(plaintext) rand.Read(plaintext)
encrypted := make([]byte, bufferSize) encrypted := make([]byte, bufferSize)

View File

@ -6,22 +6,23 @@ import (
"crypto/rand" "crypto/rand"
) )
// TODO: Use [32]byte for simplicity everywhere.
type dataCipher struct { type dataCipher struct {
key []byte key [32]byte
aead cipher.AEAD aead cipher.AEAD
} }
func newDataCipher() *dataCipher { func newDataCipher() *dataCipher {
key := make([]byte, 32) key := [32]byte{}
if _, err := rand.Read(key); err != nil { if _, err := rand.Read(key[:]); err != nil {
panic(err) panic(err)
} }
return newDataCipherFromKey(key) return newDataCipherFromKey(key)
} }
// key must be 32 bytes. // key must be 32 bytes.
func newDataCipherFromKey(key []byte) *dataCipher { func newDataCipherFromKey(key [32]byte) *dataCipher {
block, err := aes.NewCipher(key) block, err := aes.NewCipher(key[:])
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -34,14 +35,14 @@ func newDataCipherFromKey(key []byte) *dataCipher {
return &dataCipher{key: key, aead: aead} return &dataCipher{key: key, aead: aead}
} }
func (sc *dataCipher) Key() []byte { func (sc *dataCipher) Key() [32]byte {
return sc.key return sc.key
} }
func (sc *dataCipher) Encrypt(h xHeader, data, out []byte) []byte { func (sc *dataCipher) Encrypt(h xHeader, data, out []byte) []byte {
const s = dataHeaderSize const s = dataHeaderSize
out = out[:s+dataCipherOverhead+len(data)] out = out[:s+dataCipherOverhead+len(data)]
h.Marshal(dataStreamID, out[:s]) h.Marshal(out[:s])
sc.aead.Seal(out[s:s], out[:s], data, nil) sc.aead.Seal(out[s:s], out[:s], data, nil)
return out return out
} }

View File

@ -23,6 +23,7 @@ func TestDataCipher(t *testing.T) {
for _, plaintext := range testCases { for _, plaintext := range testCases {
h1 := xHeader{ h1 := xHeader{
StreamID: dataStreamID,
Counter: 235153, Counter: 235153,
SourceIP: 4, SourceIP: 4,
DestIP: 88, DestIP: 88,

View File

@ -1,26 +0,0 @@
package node
import "golang.org/x/crypto/nacl/box"
type routingCipher struct {
sharedKey [32]byte
}
func newRoutingCipher(privKey, pubKey []byte) routingCipher {
shared := [32]byte{}
box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey))
return routingCipher{shared}
}
func (rc routingCipher) Encrypt(h xHeader, data, out []byte) []byte {
const s = routingHeaderSize
out = out[:s+routingCipherOverhead+len(data)]
h.Marshal(routingStreamID, out[:s])
box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &rc.sharedKey)
return out
}
func (rc routingCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) {
const s = routingHeaderSize
return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &rc.sharedKey)
}

View File

@ -1,6 +1,7 @@
package node package node
import ( import (
"io"
"log" "log"
"net" "net"
"net/netip" "net/netip"
@ -9,6 +10,48 @@ import (
"vppn/fasttime" "vppn/fasttime"
) )
// ----------------------------------------------------------------------------
type connWriter2 struct {
lock sync.Mutex
conn *net.UDPConn
}
func newConnWriter2(conn *net.UDPConn) *connWriter2 {
return &connWriter2{conn: conn}
}
func (w *connWriter2) WriteTo(packet []byte, addr netip.AddrPort) {
w.lock.Lock()
if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil {
log.Fatalf("Failed to write to UDP port: %v", err)
}
w.lock.Unlock()
}
// ----------------------------------------------------------------------------
type ifWriter struct {
lock sync.Mutex
iface io.ReadWriteCloser
}
func newIFWriter(iface io.ReadWriteCloser) *ifWriter {
return &ifWriter{iface: iface}
}
func (w *ifWriter) Write(packet []byte) {
w.lock.Lock()
if _, err := w.iface.Write(packet); err != nil {
log.Fatalf("Failed to write to interface: %v", err)
}
w.lock.Unlock()
}
// ----------------------------------------------------------------------------
// TODO: Delete below??
type connWriter struct { type connWriter struct {
*net.UDPConn *net.UDPConn
lock sync.Mutex lock sync.Mutex

View File

@ -5,30 +5,33 @@ import "unsafe"
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
const ( const (
routingStreamID = 2 controlStreamID = 2
routingHeaderSize = 24 controlHeaderSize = 24
routingCipherOverhead = 16 controlCipherOverhead = 16
dataStreamID = 1 dataStreamID = 1
dataHeaderSize = 12 dataHeaderSize = 12
dataCipherOverhead = 16 dataCipherOverhead = 16
forwardStreamID = 3
) )
// TODO: Rename
type xHeader struct { type xHeader struct {
StreamID byte
Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic. Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic.
SourceIP byte SourceIP byte
DestIP byte DestIP byte
} }
func (h *xHeader) Parse(b []byte) { func (h *xHeader) Parse(b []byte) {
h.StreamID = b[0]
h.Counter = *(*uint64)(unsafe.Pointer(&b[1])) h.Counter = *(*uint64)(unsafe.Pointer(&b[1]))
h.SourceIP = b[9] h.SourceIP = b[9]
h.DestIP = b[10] h.DestIP = b[10]
} }
func (h *xHeader) Marshal(streamID byte, buf []byte) { func (h *xHeader) Marshal(buf []byte) {
buf[0] = streamID buf[0] = h.StreamID
*(*uint64)(unsafe.Pointer(&buf[1])) = h.Counter *(*uint64)(unsafe.Pointer(&buf[1])) = h.Counter
buf[9] = h.SourceIP buf[9] = h.SourceIP
buf[10] = h.DestIP buf[10] = h.DestIP
@ -40,7 +43,7 @@ func (h *xHeader) Marshal(streamID byte, buf []byte) {
const ( const (
headerSize = 24 headerSize = 24
streamData = 1 streamData = 1
streamRouting = 2 streamControl = 2
) )
type header struct { type header struct {

View File

@ -3,18 +3,17 @@ package node
import "testing" import "testing"
func TestHeaderMarshalParse(t *testing.T) { func TestHeaderMarshalParse(t *testing.T) {
nIn := header{ nIn := xHeader{
StreamID: 23,
Counter: 3212, Counter: 3212,
SourceIP: 34, SourceIP: 34,
DestIP: 200, DestIP: 200,
Forward: 1,
Stream: 44,
} }
buf := make([]byte, headerSize) buf := make([]byte, headerSize)
nIn.Marshal(buf) nIn.Marshal(buf)
nOut := header{} nOut := xHeader{}
nOut.Parse(buf) nOut.Parse(buf)
if nIn != nOut { if nIn != nOut {
t.Fatal(nIn, nOut) t.Fatal(nIn, nOut)

View File

@ -102,15 +102,19 @@ func main(netName, listenIP string, port uint16) {
log.Fatalf("Failed to open UDP port: %v", err) log.Fatalf("Failed to open UDP port: %v", err)
} }
routing := newRoutingTable() connWriter := newConnWriter2(conn)
ifWriter := newIFWriter(iface)
w := newConnWriter(conn, conf.PeerIP, routing) peers := remotePeers{}
r := newConnReader(conn, conf.PeerIP, routing)
router := newRouter(netName, conf, routing, w) for i := range peers {
peers[i] = newRemotePeer(conf, byte(i), ifWriter, connWriter)
}
go newHubPoller(netName, conf, peers).Run()
go readFromConn(conn, peers)
readFromIFace(iface, peers)
go nodeConnReader(r, w, iface, router)
nodeIFaceReader(w, iface, router)
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
@ -127,43 +131,39 @@ func determinePort(confPort, portFromCommandLine uint16) uint16 {
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
func nodeConnReader(r *connReader, w *connWriter, iface io.ReadWriteCloser, router *router) { func readFromConn(conn *net.UDPConn, peers remotePeers) {
defer panicHandler() defer panicHandler()
var ( var (
remoteAddr netip.AddrPort remoteAddr netip.AddrPort
h header n int
err error
buf = make([]byte, bufferSize) buf = make([]byte, bufferSize)
data []byte data []byte
err error h xHeader
) )
for { for {
remoteAddr, h, data = r.Read(buf) n, remoteAddr, err = conn.ReadFromUDPAddrPort(buf[:bufferSize])
if err != nil {
if h.Forward != 0 { log.Fatalf("Failed to read from UDP port: %v", err)
w.Forward(h.DestIP, data)
continue
} }
switch h.Stream { data = buf[:n]
case streamData: if n < headerSize {
if _, err = iface.Write(data); err != nil { continue // Packet it soo short.
log.Printf("Malformed data from peer %d: %v", h.SourceIP, err)
}
case streamRouting:
router.HandlePacket(h.SourceIP, remoteAddr, data)
default:
log.Printf("Dropping unknown stream: %d", h.Stream)
} }
h.Parse(data)
peers[h.SourceIP].HandlePacket(remoteAddr, h, data)
} }
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
func nodeIFaceReader(w *connWriter, iface io.ReadWriteCloser, router *router) { func readFromIFace(iface io.ReadWriteCloser, peers remotePeers) {
var ( var (
buf = make([]byte, bufferSize) buf = make([]byte, bufferSize)
@ -173,16 +173,11 @@ func nodeIFaceReader(w *connWriter, iface io.ReadWriteCloser, router *router) {
) )
for { for {
packet, remoteIP, err = readNextPacket(iface, buf) packet, remoteIP, err = readNextPacket(iface, buf)
if err != nil { if err != nil {
log.Fatalf("Failed to read from interface: %v", err) log.Fatalf("Failed to read from interface: %v", err)
} }
if remoteIP == w.localIP { peers[remoteIP].SendData(packet)
continue // Don't write to self.
}
w.WriteTo(remoteIP, streamData, packet)
} }
} }

View File

@ -16,10 +16,10 @@ const (
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
type packetWrapper struct { type controlPacket struct {
SrcIP byte SrcIP byte
RemoteAddr netip.AddrPort RemoteAddr netip.AddrPort
Packet any Payload any
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
@ -46,13 +46,13 @@ func (p pingPacket) Marshal(buf []byte) []byte {
return buf return buf
} }
func (p *pingPacket) Parse(buf []byte) error { func parsePingPacket(buf []byte) (p pingPacket, err error) {
if len(buf) != 41 { if len(buf) != 41 {
return errMalformedPacket return p, errMalformedPacket
} }
p.SentAt = *(*int64)(unsafe.Pointer(&buf[1])) p.SentAt = *(*int64)(unsafe.Pointer(&buf[1]))
copy(p.SharedKey[:], buf[9:41]) copy(p.SharedKey[:], buf[9:41])
return nil return
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
@ -78,12 +78,11 @@ func (p pongPacket) Marshal(buf []byte) []byte {
return buf return buf
} }
func (p *pongPacket) Parse(buf []byte) error { func parsePongPacket(buf []byte) (p pongPacket, err error) {
if len(buf) != 17 { if len(buf) != 17 {
return errMalformedPacket return p, errMalformedPacket
} }
p.SentAt = *(*int64)(unsafe.Pointer(&buf[1])) p.SentAt = *(*int64)(unsafe.Pointer(&buf[1]))
p.RecvdAt = *(*int64)(unsafe.Pointer(&buf[9])) p.RecvdAt = *(*int64)(unsafe.Pointer(&buf[9]))
return
return nil
} }

View File

@ -15,8 +15,8 @@ func TestPacketPing(t *testing.T) {
p := newPingPacket(sharedKey) p := newPingPacket(sharedKey)
out := p.Marshal(buf) out := p.Marshal(buf)
p2 := pingPacket{} p2, err := parsePingPacket(out)
if err := p2.Parse(out); err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -31,8 +31,8 @@ func TestPacketPong(t *testing.T) {
p := newPongPacket(123566) p := newPongPacket(123566)
out := p.Marshal(buf) out := p.Marshal(buf)
p2 := pongPacket{} p2, err := parsePongPacket(out)
if err := p2.Parse(out); err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

97
node/peer-pollhub.go Normal file
View File

@ -0,0 +1,97 @@
package node
import (
"encoding/json"
"io"
"log"
"net/http"
"net/url"
"time"
"vppn/m"
)
type hubPoller struct {
netName string
localIP byte
client *http.Client
req *http.Request
peers remotePeers
}
func newHubPoller(netName string, conf m.PeerConfig, peers remotePeers) *hubPoller {
u, err := url.Parse(conf.HubAddress)
if err != nil {
log.Fatalf("Failed to parse hub address %s: %v", conf.HubAddress, err)
}
u.Path = "/peer/fetch-state/"
client := &http.Client{Timeout: 8 * time.Second}
req := &http.Request{
Method: http.MethodGet,
URL: u,
Header: http.Header{},
}
req.SetBasicAuth("", conf.APIKey)
return &hubPoller{
netName: netName,
localIP: conf.PeerIP,
client: client,
req: req,
peers: peers,
}
}
func (hp *hubPoller) Run() {
defer panicHandler()
state, err := loadNetworkState(hp.netName)
if err != nil {
log.Printf("Failed to load network state: %v", err)
log.Printf("Polling hub...")
hp.pollHub()
} else {
hp.applyNetworkState(state)
}
for range time.Tick(64 * time.Second) {
hp.pollHub()
}
}
func (hp *hubPoller) pollHub() {
var state m.NetworkState
log.Printf("Fetching peer state...")
resp, err := hp.client.Do(hp.req)
if err != nil {
log.Printf("Failed to fetch peer state: %v", err)
return
}
body, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
log.Printf("Failed to read body from hub: %v", err)
return
}
if err := json.Unmarshal(body, &state); err != nil {
log.Printf("Failed to unmarshal response from hub: %v", err)
return
}
hp.applyNetworkState(state)
if err := storeNetworkState(hp.netName, state); err != nil {
log.Printf("Failed to store network state: %v", err)
}
}
func (hp *hubPoller) applyNetworkState(state m.NetworkState) {
for i := range state.Peers {
if i != int(hp.localIP) {
hp.peers[i].HandlePeerUpdate(state.Peers[i])
}
}
}

197
node/peer-supervisor.go Normal file
View File

@ -0,0 +1,197 @@
package node
import (
"net/netip"
"sync/atomic"
"time"
"vppn/m"
)
const (
connectTimeout = 6 * time.Second
pingInterval = 6 * time.Second
timeoutInterval = 20 * time.Second
)
type stateFunc func() stateFunc
type peerSuper struct {
*remotePeer
peer *m.Peer
remotePublic bool
peerData peerData
pktBuf []byte
encBuf []byte
}
func newPeerSuper(rp *remotePeer) *peerSuper {
return &peerSuper{
remotePeer: rp,
peer: nil,
pktBuf: make([]byte, bufferSize),
encBuf: make([]byte, bufferSize),
}
}
func (rp *peerSuper) Run() {
defer panicHandler()
state := rp.stateInit
for {
state = state()
}
}
// ----------------------------------------------------------------------------
func (rp *peerSuper) stateInit() stateFunc {
//rp.logf("STATE: Init")
x := peerData{}
rp.shared.Store(&x)
rp.peerData.controlCipher = nil
rp.peerData.dataCipher = nil
rp.peerData.remoteAddr = zeroAddrPort
if rp.peer == nil {
return rp.stateDisconnected
}
var addr netip.Addr
addr, rp.remotePublic = netip.AddrFromSlice(rp.peer.PublicIP)
if rp.remotePublic {
rp.peerData.remoteAddr = netip.AddrPortFrom(addr, rp.peer.Port)
}
rp.peerData.controlCipher = newControlCipher(rp.privKey, rp.peer.EncPubKey)
return rp.stateSelectRole()
}
// ----------------------------------------------------------------------------
func (rp *peerSuper) stateDisconnected() stateFunc {
//rp.logf("STATE: Disconnected")
for {
select {
case <-rp.controlPackets:
// Drop
case rp.peer = <-rp.peerUpdates:
return rp.stateInit
}
}
}
// ----------------------------------------------------------------------------
func (rp *peerSuper) stateSelectRole() stateFunc {
rp.logf("STATE: SelectRole")
if !rp.localPublic && !rp.remotePublic {
// TODO!
return rp.stateDisconnected
}
if !rp.localPublic {
return rp.stateServer
} else if !rp.remotePublic {
return rp.stateClient
}
if rp.localIP < rp.peer.PeerIP {
return rp.stateClient
}
return rp.stateServer
}
// The remote is a server.
func (rp *peerSuper) stateServer() stateFunc {
rp.logf("STATE: Server")
rp.peerData.dataCipher = newDataCipher()
rp.updateShared()
var (
pingTimer = time.NewTimer(pingInterval)
ping = pingPacket{SharedKey: ([32]byte)(rp.peerData.dataCipher.Key())}
)
defer pingTimer.Stop()
ping.SentAt = time.Now().UnixMilli()
rp.sendControlPacket(ping)
for {
select {
case <-pingTimer.C:
ping.SentAt = time.Now().UnixMilli()
rp.sendControlPacket(ping)
pingTimer.Reset(pingInterval)
case <-rp.controlPackets:
// Ignore
case rp.peer = <-rp.peerUpdates:
return rp.stateInit
}
}
}
// ----------------------------------------------------------------------------
// The remote is a client.
func (rp *peerSuper) stateClient() stateFunc {
rp.logf("STATE: Client")
rp.updateShared()
// TODO: Could use timeout to set dataCipher to nil.
var currentKey = [32]byte{}
for {
select {
case cPkt := <-rp.controlPackets:
if cPkt.RemoteAddr != rp.peerData.remoteAddr {
rp.peerData.remoteAddr = cPkt.RemoteAddr
rp.logf("Got new remote address: %v", cPkt.RemoteAddr)
rp.updateShared()
}
ping, ok := cPkt.Payload.(pingPacket)
if !ok {
continue
}
if ping.SharedKey != currentKey {
rp.logf("Connected with new shared key")
currentKey = ping.SharedKey
rp.peerData.dataCipher = newDataCipherFromKey(currentKey)
rp.updateShared()
}
rp.sendControlPacket(newPongPacket(ping.SentAt))
case rp.peer = <-rp.peerUpdates:
return rp.stateInit
}
}
}
// ----------------------------------------------------------------------------
func (rp *peerSuper) updateShared() {
data := rp.peerData
rp.shared.Store(&data)
}
// ----------------------------------------------------------------------------
func (rp *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
buf := pkt.Marshal(rp.pktBuf)
h := xHeader{
StreamID: controlStreamID,
Counter: atomic.AddUint64(&rp.counter, 1),
SourceIP: rp.localIP,
DestIP: rp.remoteIP,
}
buf = rp.peerData.controlCipher.Encrypt(h, buf, rp.encBuf)
rp.conn.WriteTo(buf, rp.peerData.remoteAddr)
}

View File

@ -1 +1,206 @@
package node package node
import (
"fmt"
"log"
"net/netip"
"sync/atomic"
"time"
"vppn/m"
)
type remotePeers [256]*remotePeer
type peerData struct {
controlCipher *controlCipher
dataCipher *dataCipher
remoteAddr netip.AddrPort
}
type remotePeer struct {
// Immutable data.
localIP byte
remoteIP byte
privKey []byte
localPublic bool // True if local node is public.
iface *ifWriter
conn *connWriter2
// Shared state.
shared *atomic.Pointer[peerData]
// Only used in HandlePeerUpdate.
peerVersion int64
// Only used in HandlePacket / Not synchronized.
dupCheck *dupCheck
decryptBuf []byte
// Only used in SendData / Not synchronized.
encryptBuf []byte
// Used for sending control and data packets. Atomic access only.
counter uint64
// For communicating with the supervisor thread.
peerUpdates chan *m.Peer
controlPackets chan controlPacket
}
func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *connWriter2) *remotePeer {
rp := &remotePeer{
localIP: conf.PeerIP,
remoteIP: remoteIP,
privKey: conf.EncPrivKey,
localPublic: addrIsValid(conf.PublicIP),
iface: iface,
conn: conn,
shared: &atomic.Pointer[peerData]{},
dupCheck: newDupCheck(0),
decryptBuf: make([]byte, bufferSize),
encryptBuf: make([]byte, bufferSize),
counter: uint64(time.Now().Unix()) << 30,
peerUpdates: make(chan *m.Peer),
controlPackets: make(chan controlPacket, 512),
}
pd := peerData{}
rp.shared.Store(&pd)
go newPeerSuper(rp).Run()
return rp
}
func (rp *remotePeer) logf(msg string, args ...any) {
log.Printf(fmt.Sprintf("[%03d] ", rp.remoteIP)+msg, args...)
}
func (rp *remotePeer) HandlePeerUpdate(peer *m.Peer) {
if peer != nil && peer.Version != rp.peerVersion {
rp.peerUpdates <- peer
rp.peerVersion = peer.Version
}
}
// ----------------------------------------------------------------------------
// HandlePacket accepts a raw data packet coming in from the network.
//
// This function is called by a single thread.
func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h xHeader, data []byte) {
switch h.StreamID {
case controlStreamID:
rp.handleControlPacket(addr, h, data)
case dataStreamID:
rp.handleDataPacket(data)
case forwardStreamID:
fallthrough
// TODO
//rp.handleForwardPacket(h, data)
default:
rp.logf("Unknown stream ID: %d", h.StreamID)
}
}
// ----------------------------------------------------------------------------
func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h xHeader, data []byte) {
shared := rp.shared.Load()
if shared.controlCipher == nil {
rp.logf("Not connected (control).")
return
}
out, ok := shared.controlCipher.Decrypt(data, rp.decryptBuf)
if !ok {
rp.logf("Failed to decrypt control packet.")
return
}
if len(out) == 0 {
rp.logf("Empty control packet from: %d", h.SourceIP)
return
}
if rp.dupCheck.IsDup(h.Counter) {
rp.logf("Duplicate control packet: %d", h.Counter)
return
}
if h.DestIP != rp.localIP {
// TODO: Forward control packet.
// TODO: Probably this should be dropped.
// Control packets should be forwarded as data for efficiency.
return
}
pkt := controlPacket{
SrcIP: h.SourceIP,
RemoteAddr: addr,
}
var err error
switch out[0] {
case packetTypePing:
pkt.Payload, err = parsePingPacket(out)
case packetTypePong:
pkt.Payload, err = parsePongPacket(out)
default:
rp.logf("Unknown control packet type: %d", out[0])
return
}
if err != nil {
rp.logf("Failed to parse control packet: %v", err)
return
}
select {
case rp.controlPackets <- pkt:
default:
rp.logf("Dropping control packet.")
}
}
func (rp *remotePeer) handleDataPacket(data []byte) {
shared := rp.shared.Load()
if shared.dataCipher == nil {
rp.logf("Not connected (recv).")
return
}
dec, ok := shared.dataCipher.Decrypt(data, rp.decryptBuf)
if !ok {
rp.logf("Failed to decrypt data packet.")
return
}
rp.iface.Write(dec)
}
// ----------------------------------------------------------------------------
// SendData sends data coming from the interface going to the network.
//
// This function is called by a single thread.
func (rp *remotePeer) SendData(data []byte) {
shared := rp.shared.Load()
if shared.dataCipher == nil || shared.remoteAddr == zeroAddrPort {
rp.logf("Not connected (send).")
return
}
h := xHeader{
StreamID: dataStreamID,
Counter: atomic.AddUint64(&rp.counter, 1),
SourceIP: rp.localIP,
DestIP: rp.remoteIP,
}
enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf)
rp.conn.WriteTo(enc, shared.remoteAddr)
}

View File

@ -8,12 +8,6 @@ import (
"vppn/m" "vppn/m"
) )
const (
connectTimeout = 6 * time.Second
pingInterval = 6 * time.Second
timeoutInterval = 20 * time.Second
)
type routingPacketWrapper struct { type routingPacketWrapper struct {
routingPacket routingPacket
Addr netip.AddrPort // Source. Addr netip.AddrPort // Source.
@ -113,8 +107,6 @@ func (s *peerSupervisor) HandlePacket(w routingPacketWrapper) {
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
type stateFunc func() stateFunc
func (s *peerSupervisor) stateInit() stateFunc { func (s *peerSupervisor) stateInit() stateFunc {
if s.peer == nil { if s.peer == nil {
return s.stateDisconnected return s.stateDisconnected
@ -316,12 +308,12 @@ func (s *peerSupervisor) updateRoutingTable(up bool) {
func (s *peerSupervisor) sendPing() uint64 { func (s *peerSupervisor) sendPing() uint64 {
traceID := newTraceID() traceID := newTraceID()
pkt := newRoutingPacket(packetTypePing, traceID) pkt := newRoutingPacket(packetTypePing, traceID)
s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf)) s.w.WriteTo(s.peer.PeerIP, streamControl, pkt.Marshal(s.buf))
s.pingTimer.Reset(pingInterval) s.pingTimer.Reset(pingInterval)
return traceID return traceID
} }
func (s *peerSupervisor) sendPong(traceID uint64) { func (s *peerSupervisor) sendPong(traceID uint64) {
pkt := newRoutingPacket(packetTypePong, traceID) pkt := newRoutingPacket(packetTypePong, traceID)
s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf)) s.w.WriteTo(s.peer.PeerIP, streamControl, pkt.Marshal(s.buf))
} }

View File

@ -19,7 +19,7 @@ type peer struct {
Up bool // No data will be sent to peers that are down. Up bool // No data will be sent to peers that are down.
Addr netip.AddrPort // If we have direct connection, otherwise use mediator. Addr netip.AddrPort // If we have direct connection, otherwise use mediator.
Mediator bool // True if the peer will mediate. Mediator bool // True if the peer will mediate.
RoutingCipher routingCipher RoutingCipher controlCipher
DataCipher dataCipher DataCipher dataCipher
// TODO: Deprecated below. // TODO: Deprecated below.