refactor-for-testability #3

Merged
johnnylee merged 26 commits from refactor-for-testability into main 2025-03-01 20:02:27 +00:00
19 changed files with 110 additions and 333 deletions
Showing only changes of commit e1a5f50e1a - Show all commits

View File

@ -1,40 +1,44 @@
package peer
import (
"io"
"log"
"net/netip"
"sync/atomic"
)
type connReader struct {
conn udpReader
iface ifWriter
sender encryptedPacketSender
super controlMsgHandler
// Input
readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error)
// Output
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error)
iface io.Writer
handleControlMsg func(fromIP byte, pkt any)
localIP byte
peers [256]*atomic.Pointer[remotePeer]
rt *atomic.Pointer[routingTable]
buf []byte
decBuf []byte
}
func newConnReader(
conn udpReader,
ifWriter ifWriter,
sender encryptedPacketSender,
super controlMsgHandler,
localIP byte,
peers [256]*atomic.Pointer[remotePeer],
readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error),
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error),
iface io.Writer,
handleControlMsg func(fromIP byte, pkt any),
rt *atomic.Pointer[routingTable],
) *connReader {
return &connReader{
conn: conn,
iface: ifWriter,
sender: sender,
super: super,
localIP: localIP,
peers: peers,
buf: make([]byte, bufferSize),
decBuf: make([]byte, bufferSize),
readFromUDPAddrPort: readFromUDPAddrPort,
writeToUDPAddrPort: writeToUDPAddrPort,
iface: iface,
handleControlMsg: handleControlMsg,
localIP: rt.Load().LocalIP,
rt: rt,
buf: newBuf(),
decBuf: newBuf(),
}
}
@ -44,13 +48,11 @@ func (r *connReader) Run() {
}
}
func (r *connReader) logf(s string, args ...any) {
log.Printf("[ConnReader] "+s, args...)
}
func (r *connReader) handleNextPacket() {
buf := r.buf[:bufferSize]
n, remoteAddr, err := r.conn.ReadFromUDPAddrPort(buf)
log.Printf("Getting next packet...")
n, remoteAddr, err := r.readFromUDPAddrPort(buf)
log.Printf("Packet from %v...", remoteAddr)
if err != nil {
log.Fatalf("Failed to read from UDP port: %v", err)
}
@ -64,23 +66,22 @@ func (r *connReader) handleNextPacket() {
buf = buf[:n]
h := parseHeader(buf)
peer := r.peers[h.SourceIP].Load()
rt := r.rt.Load()
peer := rt.Peers[h.SourceIP]
switch h.StreamID {
case controlStreamID:
r.handleControlPacket(peer, remoteAddr, h, buf)
r.handleControlPacket(remoteAddr, peer, h, buf)
case dataStreamID:
r.handleDataPacket(peer, h, buf)
r.handleDataPacket(rt, peer, h, buf)
default:
r.logf("Unknown stream ID: %d", h.StreamID)
}
}
func (r *connReader) handleControlPacket(
peer *remotePeer,
addr netip.AddrPort,
remoteAddr netip.AddrPort,
peer remotePeer,
h header,
enc []byte,
) {
@ -93,22 +94,27 @@ func (r *connReader) handleControlPacket(
return
}
msg, err := decryptControlPacket(peer, addr, h, enc, r.decBuf)
msg, err := peer.DecryptControlPacket(remoteAddr, h, enc, r.decBuf)
if err != nil {
r.logf("Failed to decrypt control packet: %v", err)
return
}
r.super.HandleControlMsg(msg)
r.handleControlMsg(h.SourceIP, msg)
}
func (r *connReader) handleDataPacket(peer *remotePeer, h header, enc []byte) {
func (r *connReader) handleDataPacket(
rt *routingTable,
peer remotePeer,
h header,
enc []byte,
) {
if !peer.Up {
r.logf("Not connected (recv).")
return
}
data, err := decryptDataPacket(peer, h, enc, r.decBuf)
data, err := peer.DecryptDataPacket(h, enc, r.decBuf)
if err != nil {
r.logf("Failed to decrypt data packet: %v", err)
return
@ -121,11 +127,15 @@ func (r *connReader) handleDataPacket(peer *remotePeer, h header, enc []byte) {
return
}
destPeer := r.peers[h.DestIP].Load()
if !destPeer.Up {
r.logf("Not connected (relay): %d", destPeer.IP)
relay, ok := rt.GetRelay()
if !ok {
r.logf("Relay not available.")
return
}
r.sender.SendEncryptedDataPacket(data, destPeer)
r.writeToUDPAddrPort(data, relay.DirectAddr)
}
func (r *connReader) logf(format string, args ...any) {
log.Printf("[ConnReader] "+format, args...)
}

View File

@ -1,141 +0,0 @@
package peer
import (
"io"
"log"
"net/netip"
"sync/atomic"
)
type ConnReader struct {
// Input
readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error)
// Output
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error)
iface io.Writer
handleControlMsg func(fromIP byte, pkt any)
localIP byte
rt *atomic.Pointer[routingTable]
buf []byte
decBuf []byte
}
func NewConnReader(
readFromUDPAddrPort func([]byte) (int, netip.AddrPort, error),
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error),
iface io.Writer,
handleControlMsg func(fromIP byte, pkt any),
rt *atomic.Pointer[routingTable],
) *ConnReader {
return &ConnReader{
readFromUDPAddrPort: readFromUDPAddrPort,
writeToUDPAddrPort: writeToUDPAddrPort,
iface: iface,
handleControlMsg: handleControlMsg,
localIP: rt.Load().LocalIP,
rt: rt,
buf: newBuf(),
decBuf: newBuf(),
}
}
func (r *ConnReader) Run() {
for {
r.handleNextPacket()
}
}
func (r *ConnReader) handleNextPacket() {
buf := r.buf[:bufferSize]
log.Printf("Getting next packet...")
n, remoteAddr, err := r.readFromUDPAddrPort(buf)
log.Printf("Packet from %v...", remoteAddr)
if err != nil {
log.Fatalf("Failed to read from UDP port: %v", err)
}
if n < headerSize {
return
}
remoteAddr = netip.AddrPortFrom(remoteAddr.Addr().Unmap(), remoteAddr.Port())
buf = buf[:n]
h := parseHeader(buf)
rt := r.rt.Load()
peer := rt.Peers[h.SourceIP]
switch h.StreamID {
case controlStreamID:
r.handleControlPacket(remoteAddr, peer, h, buf)
case dataStreamID:
r.handleDataPacket(rt, peer, h, buf)
default:
r.logf("Unknown stream ID: %d", h.StreamID)
}
}
func (r *ConnReader) handleControlPacket(
remoteAddr netip.AddrPort,
peer remotePeer,
h header,
enc []byte,
) {
if peer.ControlCipher == nil {
return
}
if h.DestIP != r.localIP {
r.logf("Incorrect destination IP on control packet: %d", h.DestIP)
return
}
msg, err := peer.DecryptControlPacket(remoteAddr, h, enc, r.decBuf)
if err != nil {
r.logf("Failed to decrypt control packet: %v", err)
return
}
r.handleControlMsg(h.SourceIP, msg)
}
func (r *ConnReader) handleDataPacket(
rt *routingTable,
peer remotePeer,
h header,
enc []byte,
) {
if !peer.Up {
r.logf("Not connected (recv).")
return
}
data, err := peer.DecryptDataPacket(h, enc, r.decBuf)
if err != nil {
r.logf("Failed to decrypt data packet: %v", err)
return
}
if h.DestIP == r.localIP {
if _, err := r.iface.Write(data); err != nil {
log.Fatalf("Failed to write to interface: %v", err)
}
return
}
relay, ok := rt.GetRelay()
if !ok {
r.logf("Relay not available.")
return
}
r.writeToUDPAddrPort(data, relay.DirectAddr)
}
func (r *ConnReader) logf(format string, args ...any) {
log.Printf("[ConnReader] "+format, args...)
}

View File

@ -37,7 +37,7 @@ func generateKeys() cryptoKeys {
func encryptControlPacket(
localIP byte,
peer *remotePeer,
pkt Marshaller,
pkt marshaller,
tmp []byte,
out []byte,
) []byte {

View File

@ -13,12 +13,12 @@ func newRoutePairForTesting() (*remotePeer, *remotePeer) {
keys1 := generateKeys()
keys2 := generateKeys()
r1 := NewRemotePeer(1)
r1 := newRemotePeer(1)
r1.PubSignKey = keys1.PubSignKey
r1.ControlCipher = newControlCipher(keys1.PrivKey, keys2.PubKey)
r1.DataCipher = newDataCipher()
r2 := NewRemotePeer(2)
r2 := newRemotePeer(2)
r2.PubSignKey = keys2.PubSignKey
r2.ControlCipher = newControlCipher(keys2.PrivKey, keys1.PubKey)
r2.DataCipher = r1.DataCipher

View File

@ -27,3 +27,7 @@ var multicastAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(
func newBuf() []byte {
return make([]byte, bufferSize)
}
type marshaller interface {
Marshal([]byte) []byte
}

View File

@ -7,7 +7,7 @@ import (
"sync/atomic"
)
type IFReader struct {
type ifReader struct {
iface io.Reader
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error)
rt *atomic.Pointer[routingTable]
@ -15,22 +15,22 @@ type IFReader struct {
buf2 []byte
}
func NewIFReader(
func newIFReader(
iface io.Reader,
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error),
rt *atomic.Pointer[routingTable],
) *IFReader {
return &IFReader{iface, writeToUDPAddrPort, rt, newBuf(), newBuf()}
) *ifReader {
return &ifReader{iface, writeToUDPAddrPort, rt, newBuf(), newBuf()}
}
func (r *IFReader) Run() {
func (r *ifReader) Run() {
packet := newBuf()
for {
r.handleNextPacket(packet)
}
}
func (r *IFReader) handleNextPacket(packet []byte) {
func (r *ifReader) handleNextPacket(packet []byte) {
packet = r.readNextPacket(packet)
remoteIP, ok := r.parsePacket(packet)
if !ok {
@ -60,7 +60,7 @@ func (r *IFReader) handleNextPacket(packet []byte) {
r.writeToUDPAddrPort(enc, relay.DirectAddr)
}
func (r *IFReader) readNextPacket(buf []byte) []byte {
func (r *ifReader) readNextPacket(buf []byte) []byte {
n, err := r.iface.Read(buf[:cap(buf)])
if err != nil {
log.Fatalf("Failed to read from interface: %v", err)
@ -69,7 +69,7 @@ func (r *IFReader) readNextPacket(buf []byte) []byte {
return buf[:n]
}
func (r *IFReader) parsePacket(buf []byte) (byte, bool) {
func (r *ifReader) parsePacket(buf []byte) (byte, bool) {
n := len(buf)
if n == 0 {
return 0, false
@ -98,6 +98,6 @@ func (r *IFReader) parsePacket(buf []byte) (byte, bool) {
}
}
func (*IFReader) logf(s string, args ...any) {
func (*ifReader) logf(s string, args ...any) {
log.Printf("[IFReader] "+s, args...)
}

View File

@ -3,7 +3,6 @@ package peer
import (
"fmt"
"io"
"log"
"net"
"os"
"syscall"
@ -11,45 +10,6 @@ import (
"golang.org/x/sys/unix"
)
// Get next packet, returning packet, ip, and possible error.
func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error) {
var (
version byte
ip byte
)
for {
n, err := iface.Read(buf[:cap(buf)])
if err != nil {
return nil, ip, err
}
buf = buf[:n]
version = buf[0] >> 4
switch version {
case 4:
if n < 20 {
log.Printf("Short IPv4 packet: %d", len(buf))
continue
}
ip = buf[19]
case 6:
if len(buf) < 40 {
log.Printf("Short IPv6 packet: %d", len(buf))
continue
}
ip = buf[39]
default:
log.Printf("Invalid IP packet version: %v", version)
continue
}
return buf, ip, nil
}
}
func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) {
if len(network) != 4 {
return nil, fmt.Errorf("expected network to be 4 bytes, got %d", len(network))

View File

@ -1,49 +0,0 @@
package peer
import (
"io"
"net"
"net/netip"
)
type UDPConn interface {
ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error)
WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error)
WriteToUDP([]byte, *net.UDPAddr) (int, error)
}
type ifWriter io.Writer
type udpReader interface {
ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error)
}
type udpWriter interface {
WriteToUDPAddrPort([]byte, netip.AddrPort) (int, error)
}
type mcUDPWriter interface {
WriteToUDP([]byte, *net.UDPAddr) (int, error)
}
type Marshaller interface {
Marshal([]byte) []byte
}
type dataPacketSender interface {
SendDataPacket(pkt []byte, peer *remotePeer)
RelayDataPacket(pkt []byte, peer, relay *remotePeer)
}
type controlPacketSender interface {
SendControlPacket(pkt Marshaller, peer *remotePeer)
RelayControlPacket(pkt Marshaller, peer, relay *remotePeer)
}
type encryptedPacketSender interface {
SendEncryptedDataPacket(pkt []byte, peer *remotePeer)
}
type controlMsgHandler interface {
HandleControlMsg(pkt any)
}

View File

@ -6,7 +6,7 @@ import (
)
func Main() {
conf := Config{}
conf := peerConfig{}
flag.StringVar(&conf.NetName, "name", "", "[REQUIRED] The network name.")
flag.StringVar(&conf.HubAddress, "hub-address", "", "[REQUIRED] The hub address.")
@ -18,6 +18,6 @@ func Main() {
os.Exit(1)
}
peer := New(conf)
peer := newPeerMain(conf)
peer.Run()
}

View File

@ -1,10 +1,6 @@
package peer
import (
"log"
"sync/atomic"
)
/*
type mcReader struct {
conn udpReader
super controlMsgHandler
@ -55,3 +51,4 @@ func (r *mcReader) handleNextPacket() {
SrcAddr: remoteAddr,
})
}
*/

View File

@ -1,8 +1,6 @@
package peer
import (
"log"
"golang.org/x/crypto/nacl/sign"
)
@ -34,7 +32,9 @@ func verifyLocalDiscoveryPacket(pkt, buf []byte, pubSignKey []byte) bool {
// ----------------------------------------------------------------------------
/*
type mcWriter struct {
conn mcUDPWriter
discoveryPacket []byte
}
@ -50,4 +50,4 @@ func (w *mcWriter) SendLocalDiscovery() {
if _, err := w.conn.WriteToUDP(w.discoveryPacket, multicastAddr); err != nil {
log.Printf("[MCWriter] Failed to write multicast UDP packet: %v", err)
}
}
}*/

View File

@ -1,11 +1,6 @@
package peer
import (
"bytes"
"net"
"testing"
)
/*
// ----------------------------------------------------------------------------
// Testing that we can create and verify a local discovery packet.
@ -100,3 +95,4 @@ func TestMCWriter_SendLocalDiscovery(t *testing.T) {
t.Fatal("Verification should succeed.")
}
}
*/

View File

@ -15,21 +15,21 @@ import (
"vppn/m"
)
type Peer struct {
ifReader *IFReader
connReader *ConnReader
type peerMain struct {
ifReader *ifReader
connReader *connReader
iface io.Writer
hubPoller *hubPoller
super *Super
super *supervisor
}
type Config struct {
type peerConfig struct {
NetName string
HubAddress string
APIKey string
}
func New(conf Config) *Peer {
func newPeerMain(conf peerConfig) *peerMain {
config, err := loadPeerConfig(conf.NetName)
if err != nil {
log.Printf("Failed to load configuration: %v", err)
@ -83,15 +83,15 @@ func New(conf Config) *Peer {
rtPtr := &atomic.Pointer[routingTable]{}
rtPtr.Store(&rt)
ifReader := NewIFReader(iface, writeToUDPAddrPort, rtPtr)
super := NewSuper(writeToUDPAddrPort, rtPtr, config.PrivKey)
connReader := NewConnReader(conn.ReadFromUDPAddrPort, writeToUDPAddrPort, iface, super.HandleControlMsg, rtPtr)
ifReader := newIFReader(iface, writeToUDPAddrPort, rtPtr)
super := newSupervisor(writeToUDPAddrPort, rtPtr, config.PrivKey)
connReader := newConnReader(conn.ReadFromUDPAddrPort, writeToUDPAddrPort, iface, super.HandleControlMsg, rtPtr)
hubPoller, err := newHubPoller(config.PeerIP, conf.NetName, conf.HubAddress, conf.APIKey, super.HandleControlMsg)
if err != nil {
log.Fatalf("Failed to create hub poller: %v", err)
}
return &Peer{
return &peerMain{
iface: iface,
ifReader: ifReader,
connReader: connReader,
@ -100,14 +100,14 @@ func New(conf Config) *Peer {
}
}
func (p *Peer) Run() {
func (p *peerMain) Run() {
go p.ifReader.Run()
go p.connReader.Run()
p.super.Start()
p.hubPoller.Run()
}
func initPeerWithHub(conf Config) {
func initPeerWithHub(conf peerConfig) {
keys := generateKeys()
initURL, err := url.Parse(conf.HubAddress)

View File

@ -14,8 +14,8 @@ type P struct {
RT *atomic.Pointer[routingTable]
Conn *TestUDPConn
IFace *TestIFace
ConnReader *ConnReader
IFReader *IFReader
ConnReader *connReader
IFReader *ifReader
}
func NewPeerForTesting(n *TestNetwork, ip byte, addr netip.AddrPort) P {

View File

@ -25,7 +25,7 @@ type peerState interface {
type pState struct {
// Output.
publish func(remotePeer)
sendControlPacket func(remotePeer, Marshaller)
sendControlPacket func(remotePeer, marshaller)
// Immutable data.
localIP byte
@ -124,7 +124,7 @@ func (s *pState) logf(format string, args ...any) {
// ----------------------------------------------------------------------------
func (s *pState) SendTo(pkt Marshaller, addr netip.AddrPort) {
func (s *pState) SendTo(pkt marshaller, addr netip.AddrPort) {
if !addr.IsValid() {
return
}
@ -134,7 +134,7 @@ func (s *pState) SendTo(pkt Marshaller, addr netip.AddrPort) {
s.Send(route, pkt)
}
func (s *pState) Send(peer remotePeer, pkt Marshaller) {
func (s *pState) Send(peer remotePeer, pkt marshaller) {
if err := s.limiter.Limit(); err != nil {
s.logf("Rate limited.")
return

View File

@ -31,7 +31,7 @@ func NewPeerStateTestHarness() *PeerStateTestHarness {
publish: func(rp remotePeer) {
h.Published = rp
},
sendControlPacket: func(rp remotePeer, pkt Marshaller) {
sendControlPacket: func(rp remotePeer, pkt marshaller) {
h.Sent = append(h.Sent, PeerStateControlMsg{rp, pkt})
},
localIP: 2,

View File

@ -11,26 +11,26 @@ import (
"git.crumpington.com/lib/go/ratelimiter"
)
type Super struct {
type supervisor struct {
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error)
staged routingTable
shared *atomic.Pointer[routingTable]
peers [256]*PeerSuper
peers [256]*peerSuper
lock sync.Mutex
buf1 []byte
buf2 []byte
}
func NewSuper(
func newSupervisor(
writeToUDPAddrPort func([]byte, netip.AddrPort) (int, error),
rt *atomic.Pointer[routingTable],
privKey []byte,
) *Super {
) *supervisor {
routes := rt.Load()
s := &Super{
s := &supervisor{
writeToUDPAddrPort: writeToUDPAddrPort,
staged: *routes,
shared: rt,
@ -55,23 +55,23 @@ func NewSuper(
MaxWaitCount: 1,
}),
}
s.peers[i] = NewPeerSuper(state)
s.peers[i] = newPeerSuper(state)
}
return s
}
func (s *Super) Start() {
func (s *supervisor) Start() {
for i := range s.peers {
go s.peers[i].Run()
}
}
func (s *Super) HandleControlMsg(destIP byte, msg any) {
func (s *supervisor) HandleControlMsg(destIP byte, msg any) {
s.peers[destIP].HandleControlMsg(msg)
}
func (s *Super) send(peer remotePeer, pkt Marshaller) {
func (s *supervisor) send(peer remotePeer, pkt marshaller) {
s.lock.Lock()
defer s.lock.Unlock()
@ -90,7 +90,7 @@ func (s *Super) send(peer remotePeer, pkt Marshaller) {
s.writeToUDPAddrPort(enc, relay.DirectAddr)
}
func (s *Super) publish(rp remotePeer) {
func (s *supervisor) publish(rp remotePeer) {
s.lock.Lock()
defer s.lock.Unlock()
@ -100,7 +100,7 @@ func (s *Super) publish(rp remotePeer) {
s.shared.Store(&copy)
}
func (s *Super) ensureRelay() {
func (s *supervisor) ensureRelay() {
if _, ok := s.staged.GetRelay(); ok {
return
}
@ -116,26 +116,26 @@ func (s *Super) ensureRelay() {
// ----------------------------------------------------------------------------
type PeerSuper struct {
type peerSuper struct {
messages chan any
state peerState
}
func NewPeerSuper(state *pState) *PeerSuper {
return &PeerSuper{
func newPeerSuper(state *pState) *peerSuper {
return &peerSuper{
messages: make(chan any, 8),
state: state.OnPeerUpdate(nil),
}
}
func (s *PeerSuper) HandleControlMsg(msg any) {
func (s *peerSuper) HandleControlMsg(msg any) {
select {
case s.messages <- msg:
default:
}
}
func (s *PeerSuper) Run() {
func (s *peerSuper) Run() {
go func() {
// Randomize ping timers.
time.Sleep(time.Duration(rand.Intn(4000)) * time.Millisecond)

View File

@ -7,7 +7,7 @@ import (
)
// TODO: Remove
func NewRemotePeer(ip byte) *remotePeer {
func newRemotePeer(ip byte) *remotePeer {
counter := uint64(time.Now().Unix()<<30 + 1)
return &remotePeer{
IP: ip,
@ -58,7 +58,7 @@ func (p remotePeer) DecryptDataPacket(h header, enc, out []byte) ([]byte, error)
}
// Peer must have a ControlCipher.
func (p remotePeer) EncryptControlPacket(pkt Marshaller, tmp, out []byte) []byte {
func (p remotePeer) EncryptControlPacket(pkt marshaller, tmp, out []byte) []byte {
tmp = pkt.Marshal(tmp)
h := header{
StreamID: controlStreamID,