sym-encryption #1

Merged
johnnylee merged 18 commits from sym-encryption into main 2024-12-24 18:37:44 +00:00
13 changed files with 443 additions and 135 deletions
Showing only changes of commit 0ae0f31eae - Show all commits

61
node/cipher-data.go Normal file
View File

@ -0,0 +1,61 @@
package node
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
)
type dataCipher struct {
key []byte
aead cipher.AEAD
}
func newDataCipher() *dataCipher {
key := make([]byte, 32)
if _, err := rand.Read(key); err != nil {
panic(err)
}
return newDataCipherFromKey(key)
}
// key must be 32 bytes.
func newDataCipherFromKey(key []byte) *dataCipher {
block, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
aead, err := cipher.NewGCM(block)
if err != nil {
panic(err)
}
return &dataCipher{key: key, aead: aead}
}
func (sc *dataCipher) Key() []byte {
return sc.key
}
func (sc *dataCipher) Encrypt(h xHeader, data, out []byte) []byte {
const s = dataHeaderSize
out = out[:s+dataCipherOverhead+len(data)]
h.Marshal(dataStreamID, out[:s])
sc.aead.Seal(out[s:s], out[:s], data, nil)
return out
}
func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) {
const s = dataHeaderSize
if len(encrypted) < s+dataCipherOverhead {
ok = false
return
}
var err error
data, err = sc.aead.Open(out[:0], encrypted[:s], encrypted[s:], nil)
ok = err == nil
return
}

View File

@ -22,7 +22,7 @@ func TestDataCipher(t *testing.T) {
}
for _, plaintext := range testCases {
h1 := dataHeader{
h1 := xHeader{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
@ -31,11 +31,13 @@ func TestDataCipher(t *testing.T) {
encrypted := make([]byte, bufferSize)
dc1 := newDataCipher()
encrypted = dc1.Encrypt(&h1, plaintext, encrypted)
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
h2 := xHeader{}
h2.Parse(encrypted)
dc2 := newDataCipherFromKey(dc1.Key())
decrypted, h2, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
decrypted, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
if !ok {
t.Fatal(ok)
}
@ -64,7 +66,7 @@ func TestDataCipher_ModifyCiphertext(t *testing.T) {
}
for _, plaintext := range testCases {
h1 := dataHeader{
h1 := xHeader{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
@ -73,14 +75,14 @@ func TestDataCipher_ModifyCiphertext(t *testing.T) {
encrypted := make([]byte, bufferSize)
dc1 := newDataCipher()
encrypted = dc1.Encrypt(&h1, plaintext, encrypted)
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
encrypted[mrand.IntN(len(encrypted))]++
dc2 := newDataCipherFromKey(dc1.Key())
_, h2, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
_, ok := dc2.Decrypt(encrypted, make([]byte, bufferSize-dataHeaderSize))
if ok {
t.Fatal(ok, h2)
t.Fatal(ok)
}
}
}
@ -89,14 +91,14 @@ func TestDataCipher_ShortCiphertext(t *testing.T) {
dc1 := newDataCipher()
shortText := make([]byte, dataHeaderSize+dataCipherOverhead-1)
rand.Read(shortText)
_, _, ok := dc1.Decrypt(shortText, make([]byte, bufferSize))
_, ok := dc1.Decrypt(shortText, make([]byte, bufferSize))
if ok {
t.Fatal(ok)
}
}
func BenchmarkDataCipher_Encrypt(b *testing.B) {
h1 := dataHeader{
h1 := xHeader{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
@ -110,12 +112,12 @@ func BenchmarkDataCipher_Encrypt(b *testing.B) {
dc1 := newDataCipher()
b.ResetTimer()
for i := 0; i < b.N; i++ {
encrypted = dc1.Encrypt(&h1, plaintext, encrypted)
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
}
}
func BenchmarkDataCipher_Decrypt(b *testing.B) {
h1 := dataHeader{
h1 := xHeader{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
@ -127,12 +129,12 @@ func BenchmarkDataCipher_Decrypt(b *testing.B) {
encrypted := make([]byte, bufferSize)
dc1 := newDataCipher()
encrypted = dc1.Encrypt(&h1, plaintext, encrypted)
encrypted = dc1.Encrypt(h1, plaintext, encrypted)
decrypted := make([]byte, bufferSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
decrypted, _, _ = dc1.Decrypt(encrypted, decrypted)
decrypted, _ = dc1.Decrypt(encrypted, decrypted)
}
}

26
node/cipher-routing.go Normal file
View File

@ -0,0 +1,26 @@
package node
import "golang.org/x/crypto/nacl/box"
type routingCipher struct {
sharedKey [32]byte
}
func newRoutingCipher(privKey, pubKey []byte) routingCipher {
shared := [32]byte{}
box.Precompute(&shared, (*[32]byte)(pubKey), (*[32]byte)(privKey))
return routingCipher{shared}
}
func (rc routingCipher) Encrypt(h xHeader, data, out []byte) []byte {
const s = routingHeaderSize
out = out[:s+routingCipherOverhead+len(data)]
h.Marshal(routingStreamID, out[:s])
box.SealAfterPrecomputation(out[s:s], data, (*[24]byte)(out[:s]), &rc.sharedKey)
return out
}
func (rc routingCipher) Decrypt(encrypted, out []byte) (data []byte, ok bool) {
const s = routingHeaderSize
return box.OpenAfterPrecomputation(out[:0], encrypted[s:], (*[24]byte)(encrypted[:s]), &rc.sharedKey)
}

114
node/cipher-routing_test.go Normal file
View File

@ -0,0 +1,114 @@
package node
import (
"bytes"
"crypto/rand"
"testing"
"golang.org/x/crypto/nacl/box"
)
func newRoutingCipherForTesting() (c1, c2 routingCipher) {
pubKey1, privKey1, err := box.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
pubKey2, privKey2, err := box.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
return newRoutingCipher(privKey1[:], pubKey2[:]),
newRoutingCipher(privKey2[:], pubKey1[:])
}
func TestRoutingCipher(t *testing.T) {
c1, c2 := newRoutingCipherForTesting()
maxSizePlaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead)
rand.Read(maxSizePlaintext)
testCases := [][]byte{
make([]byte, 0),
{1},
{255},
{1, 2, 3, 4, 5},
[]byte("Hello world"),
maxSizePlaintext,
}
for _, plaintext := range testCases {
h1 := xHeader{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
encrypted := make([]byte, bufferSize)
encrypted = c1.Encrypt(h1, plaintext, encrypted)
decrypted, ok := c2.Decrypt(encrypted, make([]byte, bufferSize))
if !ok {
t.Fatal(ok)
}
if !bytes.Equal(decrypted, plaintext) {
t.Fatal("not equal")
}
}
}
func TestRoutingCipher_ShortCiphertext(t *testing.T) {
c1, _ := newRoutingCipherForTesting()
shortText := make([]byte, routingHeaderSize+routingCipherOverhead-1)
rand.Read(shortText)
_, ok := c1.Decrypt(shortText, make([]byte, bufferSize))
if ok {
t.Fatal(ok)
}
}
func BenchmarkRoutingCipher_Encrypt(b *testing.B) {
c1, _ := newRoutingCipherForTesting()
h1 := xHeader{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
plaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead)
rand.Read(plaintext)
encrypted := make([]byte, bufferSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
encrypted = c1.Encrypt(h1, plaintext, encrypted)
}
}
func BenchmarkRoutingCipher_Decrypt(b *testing.B) {
c1, c2 := newRoutingCipherForTesting()
h1 := xHeader{
Counter: 235153,
SourceIP: 4,
DestIP: 88,
}
plaintext := make([]byte, bufferSize-routingHeaderSize-routingCipherOverhead)
rand.Read(plaintext)
encrypted := make([]byte, bufferSize)
encrypted = c1.Encrypt(h1, plaintext, encrypted)
decrypted := make([]byte, bufferSize)
b.ResetTimer()
for i := 0; i < b.N; i++ {
decrypted, _ = c2.Decrypt(encrypted, decrypted)
}
}

6
node/cipher.go Normal file
View File

@ -0,0 +1,6 @@
package node
type packetCipher interface {
Encrypt(h xHeader, data, out []byte) []byte
Decrypt(encrypted, out []byte) (data []byte, ok bool)
}

View File

@ -35,6 +35,33 @@ func newConnWriter(conn *net.UDPConn, localIP byte, routing *routingTable) *conn
return w
}
/*
func (w *connWriter) SendRouting(remoteIP byte, data []byte) {
dstPeer := w.routing.Get(remoteIP)
if dstPeer == nil {
log.Printf("No peer: %d", remoteIP)
return
}
var viaPeer *peer
if dstPeer.Addr == zeroAddrPort {
viaPeer = w.routing.Mediator()
if viaPeer == nil {
log.Printf("No mediator: %d", remoteIP)
return
}
}
w.sendRouting(dstPeer, viaPeer, data)
}
*/
func (w *connWriter) SendData(remoteIP byte, data []byte) {
// TODO
}
// TODO: deprecated
func (w *connWriter) WriteTo(remoteIP, stream byte, data []byte) {
dstPeer := w.routing.Get(remoteIP)
if dstPeer == nil {
@ -50,11 +77,11 @@ func (w *connWriter) WriteTo(remoteIP, stream byte, data []byte) {
var viaPeer *peer
if dstPeer.Mediated {
viaPeer = w.routing.mediator.Load()
if viaPeer == nil || viaPeer.Addr == nil {
if viaPeer == nil || viaPeer.Addr == zeroAddrPort {
log.Printf("Mediator not connected")
return
}
} else if dstPeer.Addr == nil {
} else if dstPeer.Addr == zeroAddrPort {
log.Printf("Peer doesn't have address: %d", remoteIP)
return
}
@ -62,6 +89,7 @@ func (w *connWriter) WriteTo(remoteIP, stream byte, data []byte) {
w.WriteToPeer(dstPeer, viaPeer, stream, data)
}
// TODO: deprecated
func (w *connWriter) WriteToPeer(dstPeer, viaPeer *peer, stream byte, data []byte) {
w.lock.Lock()
@ -89,20 +117,21 @@ func (w *connWriter) WriteToPeer(dstPeer, viaPeer *peer, stream byte, data []byt
addr = viaPeer.Addr
}
if _, err := w.WriteToUDPAddrPort(buf, *addr); err != nil {
if _, err := w.WriteToUDPAddrPort(buf, addr); err != nil {
log.Fatalf("Failed to write to UDP port: %v", err)
}
w.lock.Unlock()
}
// TODO: deprecated
func (w *connWriter) Forward(dstIP byte, packet []byte) {
dstPeer := w.routing.Get(dstIP)
if dstPeer == nil || dstPeer.Addr == nil {
if dstPeer == nil || dstPeer.Addr == zeroAddrPort {
log.Printf("No peer: %d", dstIP)
return
}
if _, err := w.WriteToUDPAddrPort(packet, *dstPeer.Addr); err != nil {
if _, err := w.WriteToUDPAddrPort(packet, dstPeer.Addr); err != nil {
log.Fatalf("Failed to write to UDP port: %v", err)
}
}

View File

@ -1,97 +0,0 @@
package node
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"unsafe"
)
// ----------------------------------------------------------------------------
const (
dataStreamID = 1
dataHeaderSize = 12
dataCipherOverhead = 16 + 1
)
type dataHeader struct {
Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic.
SourceIP byte
DestIP byte
}
func (h *dataHeader) Parse(b []byte) {
h.Counter = *(*uint64)(unsafe.Pointer(&b[0]))
h.SourceIP = b[8]
h.DestIP = b[9]
}
func (h *dataHeader) Marshal(buf []byte) {
*(*uint64)(unsafe.Pointer(&buf[0])) = h.Counter
buf[8] = h.SourceIP
buf[9] = h.DestIP
buf[10] = 0
buf[11] = 0
}
// ----------------------------------------------------------------------------
type dataCipher struct {
key []byte
aead cipher.AEAD
}
func newDataCipher() *dataCipher {
key := make([]byte, 32)
if _, err := rand.Read(key); err != nil {
panic(err)
}
return newDataCipherFromKey(key)
}
// key must be 32 bytes.
func newDataCipherFromKey(key []byte) *dataCipher {
block, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
aead, err := cipher.NewGCM(block)
if err != nil {
panic(err)
}
return &dataCipher{key: key, aead: aead}
}
func (sc *dataCipher) Key() []byte {
return sc.key
}
func (sc *dataCipher) Encrypt(h *dataHeader, data, out []byte) []byte {
out = out[:dataHeaderSize+dataCipherOverhead+len(data)]
out[0] = dataStreamID
h.Marshal(out[1:])
const s = dataHeaderSize
sc.aead.Seal(out[1+s:1+s], out[1:1+s], data, nil)
return out
}
func (sc *dataCipher) Decrypt(encrypted, out []byte) (data []byte, h dataHeader, ok bool) {
const s = dataHeaderSize
if len(encrypted) < s+dataCipherOverhead {
ok = false
return
}
h.Parse(encrypted[1 : 1+s])
var err error
data, err = sc.aead.Open(out[:0], encrypted[1:1+s], encrypted[1+s:], nil)
ok = err == nil
return
}

View File

@ -2,6 +2,41 @@ package node
import "unsafe"
// ----------------------------------------------------------------------------
const (
routingStreamID = 2
routingHeaderSize = 24
routingCipherOverhead = 16
dataStreamID = 1
dataHeaderSize = 12
dataCipherOverhead = 16
)
// TODO: Rename
type xHeader struct {
Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic.
SourceIP byte
DestIP byte
}
func (h *xHeader) Parse(b []byte) {
h.Counter = *(*uint64)(unsafe.Pointer(&b[1]))
h.SourceIP = b[9]
h.DestIP = b[10]
}
func (h *xHeader) Marshal(streamID byte, buf []byte) {
buf[0] = streamID
*(*uint64)(unsafe.Pointer(&buf[1])) = h.Counter
buf[9] = h.SourceIP
buf[10] = h.DestIP
buf[11] = 0
}
// ----------------------------------------------------------------------------
// TODO: Remove this code.
const (
headerSize = 24
streamData = 1

89
node/packets.go Normal file
View File

@ -0,0 +1,89 @@
package node
import (
"errors"
"net/netip"
"time"
"unsafe"
)
var errMalformedPacket = errors.New("malformed packet")
const (
packetTypePing = iota + 1
packetTypePong
)
// ----------------------------------------------------------------------------
type packetWrapper struct {
SrcIP byte
RemoteAddr netip.AddrPort
Packet any
}
// ----------------------------------------------------------------------------
// A pingPacket is sent from a node acting as a client, to a node acting
// as a server. It always contains the shared key the client is expecting
// to use for data encryption with the server.
type pingPacket struct {
SentAt int64 // UnixMilli.
SharedKey [32]byte
}
func newPingPacket(sharedKey []byte) (pp pingPacket) {
pp.SentAt = time.Now().UnixMilli()
copy(pp.SharedKey[:], sharedKey)
return
}
func (p pingPacket) Marshal(buf []byte) []byte {
buf = buf[:41]
buf[0] = packetTypePing
*(*uint64)(unsafe.Pointer(&buf[1])) = uint64(p.SentAt)
copy(buf[9:41], p.SharedKey[:])
return buf
}
func (p *pingPacket) Parse(buf []byte) error {
if len(buf) != 41 {
return errMalformedPacket
}
p.SentAt = *(*int64)(unsafe.Pointer(&buf[1]))
copy(p.SharedKey[:], buf[9:41])
return nil
}
// ----------------------------------------------------------------------------
// A pongPacket is sent by a node in a server role in response to a pingPacket.
type pongPacket struct {
SentAt int64 // UnixMilli.
RecvdAt int64 // UnixMilli.
}
func newPongPacket(sentAt int64) (pp pongPacket) {
pp.SentAt = sentAt
pp.RecvdAt = time.Now().UnixMilli()
return
}
func (p pongPacket) Marshal(buf []byte) []byte {
buf = buf[:17]
buf[0] = packetTypePong
*(*uint64)(unsafe.Pointer(&buf[1])) = uint64(p.SentAt)
*(*uint64)(unsafe.Pointer(&buf[9])) = uint64(p.RecvdAt)
return buf
}
func (p *pongPacket) Parse(buf []byte) error {
if len(buf) != 17 {
return errMalformedPacket
}
p.SentAt = *(*int64)(unsafe.Pointer(&buf[1]))
p.RecvdAt = *(*int64)(unsafe.Pointer(&buf[9]))
return nil
}

42
node/packets_test.go Normal file
View File

@ -0,0 +1,42 @@
package node
import (
"crypto/rand"
"reflect"
"testing"
)
func TestPacketPing(t *testing.T) {
sharedKey := make([]byte, 32)
rand.Read(sharedKey)
buf := make([]byte, bufferSize)
p := newPingPacket(sharedKey)
out := p.Marshal(buf)
p2 := pingPacket{}
if err := p2.Parse(out); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(p, p2) {
t.Fatal(p, p2)
}
}
func TestPacketPong(t *testing.T) {
buf := make([]byte, bufferSize)
p := newPongPacket(123566)
out := p.Marshal(buf)
p2 := pongPacket{}
if err := p2.Parse(out); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(p, p2) {
t.Fatal(p, p2)
}
}

View File

@ -36,7 +36,7 @@ type peerSupervisor struct {
// Peer-related items.
version int64 // Ony accessed in HandlePeerUpdate.
peer *m.Peer
remoteAddrPort *netip.AddrPort
remoteAddrPort netip.AddrPort
mediated bool
sharedKey []byte
@ -123,9 +123,9 @@ func (s *peerSupervisor) stateInit() stateFunc {
addr, ok := netip.AddrFromSlice(s.peer.PublicIP)
if ok {
addrPort := netip.AddrPortFrom(addr, s.peer.Port)
s.remoteAddrPort = &addrPort
s.remoteAddrPort = addrPort
} else {
s.remoteAddrPort = nil
s.remoteAddrPort = zeroAddrPort
}
s.sharedKey = computeSharedKey(s.peer.EncPubKey, s.privKey)
@ -153,7 +153,7 @@ func (s *peerSupervisor) stateSelectRole() stateFunc {
s.logf("STATE: SelectRole")
s.updateRoutingTable(false)
if s.remoteAddrPort != nil {
if s.remoteAddrPort != zeroAddrPort {
s.mediated = false
// If both remote and local are public, one side acts as client, and one
@ -186,7 +186,7 @@ func (s *peerSupervisor) stateAccept() stateFunc {
switch pkt.Type {
case packetTypePing:
s.remoteAddrPort = &pkt.Addr
s.remoteAddrPort = pkt.Addr
s.updateRoutingTable(true)
s.sendPong(pkt.TraceID)
return s.stateConnected
@ -256,8 +256,8 @@ func (s *peerSupervisor) stateConnected() stateFunc {
// Server should always follow remote port.
if s.localPublic {
if pkt.Addr != *s.remoteAddrPort {
s.remoteAddrPort = &pkt.Addr
if pkt.Addr != s.remoteAddrPort {
s.remoteAddrPort = pkt.Addr
s.updateRoutingTable(true)
}
}

View File

@ -12,12 +12,18 @@ import (
"vppn/m"
)
var zeroAddrPort = netip.AddrPort{}
type peer struct {
Up bool // No data will be sent to peers that are down.
Mediator bool
IP byte // The VPN IP.
Up bool // No data will be sent to peers that are down.
Addr netip.AddrPort // If we have direct connection, otherwise use mediator.
Mediator bool // True if the peer will mediate.
RoutingCipher routingCipher
DataCipher dataCipher
// TODO: Deprecated below.
Mediated bool
IP byte
Addr *netip.AddrPort // If we have direct connection, otherwise use mediator.
SharedKey []byte
}
@ -48,6 +54,10 @@ func (r *routingTable) Set(ip byte, p *peer) {
r.table[ip].Store(p)
}
func (r *routingTable) Mediator() *peer {
return r.mediator.Load()
}
// ----------------------------------------------------------------------------
type router struct {

View File

@ -1,18 +1,9 @@
package node
import (
"errors"
"unsafe"
)
var errMalformedPacket = errors.New("malformed packet")
const (
// Used to maintain connection.
packetTypePing = iota + 1
packetTypePong
)
type routingPacket struct {
Type byte // One of the packetType* constants.
TraceID uint64 // For matching requests and responses.