wip: client/server working.

This commit is contained in:
jdl 2024-12-20 21:06:16 +01:00
parent 5b34b3311b
commit c7d3fe1ed8
8 changed files with 354 additions and 317 deletions

View File

@ -5,6 +5,7 @@ import (
"log"
"net"
"net/netip"
"runtime/debug"
"sync"
)
@ -22,6 +23,7 @@ func newConnWriter(conn *net.UDPConn) *connWriter {
func (w *connWriter) WriteTo(packet []byte, addr netip.AddrPort) {
w.lock.Lock()
if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil {
debug.PrintStack()
log.Fatalf("Failed to write to UDP port: %v", err)
}
w.lock.Unlock()

View File

@ -10,7 +10,7 @@ const (
controlHeaderSize = 24
dataStreamID = 1
dataHeaderSize = 12
forwardStreamID = 3
relayStreamID = 3
)
type header struct {

View File

@ -114,7 +114,6 @@ func main(netName, listenIP string, port uint16) {
go newHubPoller(netName, conf, peers).Run()
go readFromConn(conf.PeerIP, conn, peers)
readFromIFace(iface, peers)
}
// ----------------------------------------------------------------------------
@ -157,12 +156,7 @@ func readFromConn(localIP byte, conn *net.UDPConn, peers remotePeers) {
}
h.Parse(data)
if h.DestIP == localIP {
peers[h.SourceIP].HandlePacket(remoteAddr, h, data)
} else {
peers[h.DestIP].ForwardPacket(data)
}
peers[h.SourceIP].HandlePacket(remoteAddr, h, data)
}
}
@ -183,6 +177,6 @@ func readFromIFace(iface io.ReadWriteCloser, peers remotePeers) {
log.Fatalf("Failed to read from interface: %v", err)
}
peers[remoteIP].SendData(packet)
peers[remoteIP].HandleInterfacePacket(packet)
}
}

View File

@ -7,7 +7,10 @@ import (
"unsafe"
)
var errMalformedPacket = errors.New("malformed packet")
var (
errMalformedPacket = errors.New("malformed packet")
errUnknownPacketType = errors.New("unknown packet type")
)
const (
packetTypePing = iota + 1
@ -22,6 +25,18 @@ type controlPacket struct {
Payload any
}
func (p *controlPacket) ParsePayload(buf []byte) (err error) {
switch buf[0] {
case packetTypePing:
p.Payload, err = parsePingPacket(buf)
case packetTypePong:
p.Payload, err = parsePongPacket(buf)
default:
return errUnknownPacketType
}
return err
}
// ----------------------------------------------------------------------------
// A pingPacket is sent from a node acting as a client, to a node acting
@ -32,9 +47,9 @@ type pingPacket struct {
SharedKey [32]byte
}
func newPingPacket(sharedKey []byte) (pp pingPacket) {
func newPingPacket(sharedKey [32]byte) (pp pingPacket) {
pp.SentAt = time.Now().UnixMilli()
copy(pp.SharedKey[:], sharedKey)
copy(pp.SharedKey[:], sharedKey[:])
return
}

View File

@ -12,7 +12,7 @@ func TestPacketPing(t *testing.T) {
buf := make([]byte, bufferSize)
p := newPingPacket(sharedKey)
p := newPingPacket([32]byte(sharedKey))
out := p.Marshal(buf)
p2, err := parsePingPacket(out)

214
node/peer-states.go Normal file
View File

@ -0,0 +1,214 @@
package node
import (
"fmt"
"log"
"net/netip"
"sync/atomic"
"time"
"vppn/m"
)
type peerState interface {
Name() string
OnPeerUpdate(*m.Peer) peerState
OnPing(netip.AddrPort, pingPacket) peerState
OnPong(netip.AddrPort, pongPacket) peerState
OnPingTimer() peerState
OnTimeoutTimer() peerState
}
// ----------------------------------------------------------------------------
type stateBase struct {
// The purpose of this state machine is to manage this published data.
published *atomic.Pointer[peerData]
// The other remote peers.
peers *remotePeers
// Immutable data.
localIP byte
localPub bool
remoteIP byte
privKey []byte
conn *connWriter
// For sending to peer.
counter *uint64
// Mutable peer data.
peer *m.Peer
remotePub bool
data peerData // Local copy of shared data. See publish().
// Timers
pingTimer *time.Timer
timeoutTimer *time.Timer
buf []byte
encBuf []byte
}
func (sb *stateBase) Name() string { return "idle" }
func (s *stateBase) OnPeerUpdate(peer *m.Peer) peerState {
// Both nil: no change.
if peer == nil && s.peer == nil {
return nil
}
// No change.
if peer != nil && s.peer != nil && s.peer.Version == peer.Version {
return nil
}
s.peer = peer
s.data = peerData{}
s.data.controlCipher = newControlCipher(s.privKey, peer.EncPubKey)
ip, isValid := netip.AddrFromSlice(peer.PublicIP)
if isValid {
s.remotePub = true
s.data.remoteAddr = netip.AddrPortFrom(ip, peer.Port)
s.data.relay = peer.Mediator
if s.localPub && s.localIP < s.remoteIP {
return newStateServer(s)
}
return newStateClient(s)
}
if s.localPub {
return newStateServer(s)
}
// TODO: return newStateMediated(a/b)
return nil
}
func (s *stateBase) OnPing(rAddr netip.AddrPort, p pingPacket) peerState { return nil }
func (s *stateBase) OnPong(rAddr netip.AddrPort, p pongPacket) peerState { return nil }
func (s *stateBase) OnPingTimer() peerState { return nil }
func (s *stateBase) OnTimeoutTimer() peerState { return nil }
// Helpers.
func (s *stateBase) resetPingTimer() { s.pingTimer.Reset(pingInterval) }
func (s *stateBase) resetTimeoutTimer() { s.timeoutTimer.Reset(timeoutInterval) }
func (s *stateBase) stopPingTimer() { s.pingTimer.Stop() }
func (s *stateBase) stopTimeoutTimer() { s.timeoutTimer.Stop() }
func (s *stateBase) logf(msg string, args ...any) {
log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...)
}
func (s *stateBase) publish() {
data := s.data
s.published.Store(&data)
}
func (s *stateBase) sendPing(sharedKey [32]byte) {
s.sendControlPacket(newPingPacket(sharedKey))
}
func (s *stateBase) sendPong(ping pingPacket) {
s.sendControlPacket(newPongPacket(ping.SentAt))
}
func (s *stateBase) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
buf := pkt.Marshal(s.buf)
h := header{
StreamID: controlStreamID,
Counter: atomic.AddUint64(s.counter, 1),
SourceIP: s.localIP,
DestIP: s.remoteIP,
}
buf = s.data.controlCipher.Encrypt(h, buf, s.encBuf)
if s.data.relayIP == 0 {
s.conn.WriteTo(buf, s.data.remoteAddr)
return
}
// TODO: Relay!
}
// ----------------------------------------------------------------------------
type stateClient struct {
sharedKey [32]byte
*stateBase
}
func newStateClient(b *stateBase) peerState {
s := &stateClient{stateBase: b}
s.publish()
s.data.dataCipher = newDataCipher()
s.sharedKey = s.data.dataCipher.Key()
s.sendPing(s.sharedKey)
s.resetPingTimer()
s.resetTimeoutTimer()
return s
}
func (s *stateClient) Name() string { return "client" }
func (s *stateClient) OnPong(addr netip.AddrPort, p pongPacket) peerState {
if !s.data.up {
s.data.up = true
s.publish()
}
s.resetTimeoutTimer()
return nil
}
func (s *stateClient) OnPingTimer() peerState {
s.sendPing(s.sharedKey)
s.resetPingTimer()
return nil
}
func (s *stateClient) OnTimeoutTimer() peerState {
s.data.up = false
s.publish()
return nil
}
// ----------------------------------------------------------------------------
type stateServer struct {
*stateBase
}
func newStateServer(b *stateBase) peerState {
s := &stateServer{b}
s.publish()
s.stopPingTimer()
s.stopTimeoutTimer()
return s
}
func (s *stateServer) Name() string { return "server" }
func (s *stateServer) OnPing(addr netip.AddrPort, p pingPacket) peerState {
if addr != s.data.remoteAddr {
s.logf("Got new peer address: %v", addr)
s.data.remoteAddr = addr
s.data.up = true
s.publish()
}
if s.data.dataCipher == nil || p.SharedKey != s.data.dataCipher.Key() {
s.logf("Got new shared key.")
s.data.dataCipher = newDataCipherFromKey(p.SharedKey)
s.publish()
}
s.sendPong(p)
return nil
}

View File

@ -1,10 +1,6 @@
package node
import (
"log"
"math/rand"
"net/netip"
"sync/atomic"
"time"
"vppn/m"
)
@ -15,263 +11,64 @@ const (
timeoutInterval = 20 * time.Second
)
type stateFunc func() stateFunc
type peerSuper struct {
*remotePeer
peer *m.Peer
remotePublic bool
peerData peerData
pktBuf []byte
encBuf []byte
}
func newPeerSuper(rp *remotePeer) *peerSuper {
return &peerSuper{
remotePeer: rp,
peer: nil,
pktBuf: make([]byte, bufferSize),
encBuf: make([]byte, bufferSize),
}
}
func (rp *peerSuper) Run() {
func (rp *remotePeer) supervise(
conf m.PeerConfig,
remoteIP byte,
conn *connWriter,
peers *remotePeers,
) {
defer panicHandler()
state := rp.stateInit
for {
state = state()
}
}
// ----------------------------------------------------------------------------
func (rp *peerSuper) stateInit() stateFunc {
//rp.logf("STATE: Init")
x := peerData{}
rp.shared.Store(&x)
rp.peerData.relay = false
rp.peerData.controlCipher = nil
rp.peerData.dataCipher = nil
rp.peerData.remoteAddr = zeroAddrPort
rp.peerData.relayIP = 0
if rp.peer == nil {
return rp.stateDisconnected
base := &stateBase{
published: rp.published,
peers: rp.peers,
localIP: rp.localIP,
remoteIP: rp.remoteIP,
privKey: conf.EncPrivKey,
localPub: addrIsValid(conf.PublicIP),
conn: rp.conn,
counter: &rp.counter,
pingTimer: time.NewTimer(time.Second),
timeoutTimer: time.NewTimer(time.Second),
buf: make([]byte, bufferSize),
encBuf: make([]byte, bufferSize),
}
var addr netip.Addr
addr, rp.remotePublic = netip.AddrFromSlice(rp.peer.PublicIP)
if rp.remotePublic {
rp.peerData.remoteAddr = netip.AddrPortFrom(addr, rp.peer.Port)
} else {
rp.peerData.relay = false
}
rp.peerData.controlCipher = newControlCipher(rp.privKey, rp.peer.EncPubKey)
return rp.stateSelectRole()
}
// ----------------------------------------------------------------------------
func (rp *peerSuper) stateDisconnected() stateFunc {
//rp.logf("STATE: Disconnected")
for {
select {
case <-rp.controlPackets:
// Drop
case rp.peer = <-rp.peerUpdates:
return rp.stateInit
}
}
}
// ----------------------------------------------------------------------------
func (rp *peerSuper) stateSelectRole() stateFunc {
rp.logf("STATE: SelectRole")
if !rp.localPublic && !rp.remotePublic {
return rp.stateSelectMediator
}
if !rp.localPublic {
return rp.stateServer
} else if !rp.remotePublic {
return rp.stateClient
}
if rp.localIP < rp.remoteIP {
return rp.stateClient
}
return rp.stateServer
}
// ----------------------------------------------------------------------------
func (rp *peerSuper) stateSelectMediator() stateFunc {
rp.logf("STATE: SelectMediator")
for {
log.Printf("Selecting mediator...")
if ip := rp.selectMediator(); ip != 0 {
rp.logf("Got mediator: %d", ip)
rp.peerData.relayIP = ip
if rp.localIP < rp.remoteIP {
return rp.stateClient
}
return rp.stateServer
}
select {
case <-time.After(pingInterval):
continue
case rp.peer = <-rp.peerUpdates:
return rp.stateInit
}
}
}
func (rp *peerSuper) selectMediator() byte {
possible := make([]byte, 0, 8)
for _, peer := range rp.peers {
if peer.canRelay() {
rp.logf("relay: %v", peer.shared.Load())
possible = append(possible, peer.remoteIP)
}
}
if len(possible) == 0 {
return 0
}
return possible[rand.Intn(len(possible))]
}
// ----------------------------------------------------------------------------
// The remote is a server.
func (rp *peerSuper) stateServer() stateFunc {
rp.logf("STATE: Server")
rp.peerData.dataCipher = newDataCipher()
rp.updateShared()
base.pingTimer.Stop()
base.timeoutTimer.Stop()
var (
pingTimer = time.NewTimer(pingInterval)
timeoutTimer = time.NewTimer(timeoutInterval)
ping = pingPacket{SharedKey: ([32]byte)(rp.peerData.dataCipher.Key())}
curState peerState = base
nextState peerState
)
defer pingTimer.Stop()
defer timeoutTimer.Stop()
ping.SentAt = time.Now().UnixMilli()
rp.sendControlPacket(ping)
for {
nextState = nil
select {
case <-pingTimer.C:
ping.SentAt = time.Now().UnixMilli()
rp.sendControlPacket(ping)
pingTimer.Reset(pingInterval)
case peer := <-rp.peerUpdates:
nextState = curState.OnPeerUpdate(peer)
case cPkt := <-rp.controlPackets:
if _, ok := cPkt.Payload.(pongPacket); ok {
timeoutTimer.Reset(timeoutInterval)
case pkt := <-rp.controlPackets:
switch p := pkt.Payload.(type) {
case pingPacket:
nextState = curState.OnPing(pkt.RemoteAddr, p)
case pongPacket:
nextState = curState.OnPong(pkt.RemoteAddr, p)
default:
// Unknown packet type.
}
case <-timeoutTimer.C:
if rp.peerData.relayIP != 0 {
rp.logf("Timeout (server, relay)")
return rp.stateSelectMediator
} else {
rp.logf("Timeout (server)")
}
case <-base.pingTimer.C:
nextState = curState.OnPingTimer()
case rp.peer = <-rp.peerUpdates:
return rp.stateInit
case <-base.timeoutTimer.C:
nextState = curState.OnTimeoutTimer()
}
if nextState != nil {
rp.logf("%s --> %s", curState.Name(), nextState.Name())
curState = nextState
}
}
}
// ----------------------------------------------------------------------------
// The remote is a client.
func (rp *peerSuper) stateClient() stateFunc {
rp.logf("STATE: Client")
rp.updateShared()
var (
currentKey = [32]byte{}
timeoutTimer = time.NewTimer(timeoutInterval)
)
defer timeoutTimer.Stop()
for {
select {
case cPkt := <-rp.controlPackets:
if cPkt.RemoteAddr != rp.peerData.remoteAddr {
rp.peerData.remoteAddr = cPkt.RemoteAddr
rp.logf("Got new remote address: %v", cPkt.RemoteAddr)
rp.updateShared()
}
ping, ok := cPkt.Payload.(pingPacket)
if !ok {
continue
}
if ping.SharedKey != currentKey {
rp.logf("Connected with new shared key")
currentKey = ping.SharedKey
rp.peerData.up = true
rp.peerData.dataCipher = newDataCipherFromKey(currentKey)
rp.updateShared()
}
timeoutTimer.Reset(timeoutInterval)
rp.sendControlPacket(newPongPacket(ping.SentAt))
case <-timeoutTimer.C:
if rp.peerData.relayIP != 0 {
rp.logf("Timeout (server, relay)")
return rp.stateSelectMediator
} else {
rp.logf("Timeout (server)")
}
case rp.peer = <-rp.peerUpdates:
return rp.stateInit
}
}
}
// ----------------------------------------------------------------------------
func (rp *peerSuper) updateShared() {
data := rp.peerData
rp.shared.Store(&data)
}
// ----------------------------------------------------------------------------
func (rp *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
buf := pkt.Marshal(rp.pktBuf)
h := header{
StreamID: controlStreamID,
Counter: atomic.AddUint64(&rp.counter, 1),
SourceIP: rp.localIP,
DestIP: rp.remoteIP,
}
buf = rp.peerData.controlCipher.Encrypt(h, buf, rp.encBuf)
if rp.peerData.relayIP == 0 {
rp.conn.WriteTo(buf, rp.peerData.remoteAddr)
return
}
rp.peers[rp.peerData.relayIP].RelayControlData(buf)
}

View File

@ -22,19 +22,14 @@ type peerData struct {
type remotePeer struct {
// Immutable data.
localIP byte
remoteIP byte
privKey []byte
localPublic bool // True if local node is public.
iface *ifWriter
conn *connWriter
localIP byte
remoteIP byte
iface *ifWriter
conn *connWriter
// Shared state.
peers *remotePeers
shared *atomic.Pointer[peerData]
// Only used in HandlePeerUpdate.
peerVersion int64
peers *remotePeers
published *atomic.Pointer[peerData]
// Only used in HandlePacket / Not synchronized.
dupCheck *dupCheck
@ -55,12 +50,10 @@ func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *conn
rp := &remotePeer{
localIP: conf.PeerIP,
remoteIP: remoteIP,
privKey: conf.EncPrivKey,
localPublic: addrIsValid(conf.PublicIP),
iface: iface,
conn: conn,
peers: peers,
shared: &atomic.Pointer[peerData]{},
published: &atomic.Pointer[peerData]{},
dupCheck: newDupCheck(0),
decryptBuf: make([]byte, bufferSize),
encryptBuf: make([]byte, bufferSize),
@ -70,10 +63,10 @@ func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *conn
}
pd := peerData{}
rp.shared.Store(&pd)
go newPeerSuper(rp).Run()
rp.published.Store(&pd)
//go newPeerSuper(rp).Run()
go rp.supervise(conf, remoteIP, conn, peers)
return rp
}
@ -82,10 +75,7 @@ func (rp *remotePeer) logf(msg string, args ...any) {
}
func (rp *remotePeer) HandlePeerUpdate(peer *m.Peer) {
if peer != nil && peer.Version != rp.peerVersion {
rp.peerUpdates <- peer
rp.peerVersion = peer.Version
}
rp.peerUpdates <- peer
}
// ----------------------------------------------------------------------------
@ -101,6 +91,9 @@ func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h header, data []byte) {
case dataStreamID:
rp.handleDataPacket(data)
case relayStreamID:
rp.handleRelayPacket(h, data)
default:
rp.logf("Unknown stream ID: %d", h.StreamID)
}
@ -109,8 +102,9 @@ func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h header, data []byte) {
// ----------------------------------------------------------------------------
func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []byte) {
shared := rp.shared.Load()
shared := rp.published.Load()
if shared.controlCipher == nil {
log.Printf("Shared: %+v", *shared)
rp.logf("Not connected (control).")
return
}
@ -141,19 +135,7 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []
RemoteAddr: addr,
}
var err error
switch out[0] {
case packetTypePing:
pkt.Payload, err = parsePingPacket(out)
case packetTypePong:
pkt.Payload, err = parsePongPacket(out)
default:
rp.logf("Unknown control packet type: %d", out[0])
return
}
if err != nil {
if err := pkt.ParsePayload(out); err != nil {
rp.logf("Failed to parse control packet: %v", err)
return
}
@ -168,7 +150,7 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []
// ----------------------------------------------------------------------------
func (rp *remotePeer) handleDataPacket(data []byte) {
shared := rp.shared.Load()
shared := rp.published.Load()
if shared.dataCipher == nil {
rp.logf("Not connected (recv).")
return
@ -185,34 +167,65 @@ func (rp *remotePeer) handleDataPacket(data []byte) {
// ----------------------------------------------------------------------------
func (rp *remotePeer) handleRelayPacket(h header, data []byte) {
shared := rp.published.Load()
if shared.dataCipher == nil {
rp.logf("Not connected (recv).")
return
}
dec, ok := shared.dataCipher.Decrypt(data, rp.decryptBuf)
if !ok {
rp.logf("Failed to decrypt data packet.")
return
}
rp.peers[h.DestIP].sendDirect(dec)
}
// ----------------------------------------------------------------------------
// SendData sends data coming from the interface going to the network.
//
// This function is called by a single thread.
func (rp *remotePeer) SendData(data []byte) {
rp.sendData(dataStreamID, data)
rp.sendData(dataStreamID, rp.remoteIP, data)
}
// ----------------------------------------------------------------------------
func (rp *remotePeer) HandleInterfacePacket(data []byte) {
shared := rp.published.Load()
func (rp *remotePeer) RelayControlData(data []byte) {
rp.sendData(forwardStreamID, data)
}
// ----------------------------------------------------------------------------
func (rp *remotePeer) ForwardPacket(data []byte) {
shared := rp.shared.Load()
if shared.remoteAddr == zeroAddrPort {
rp.logf("Not connected (forward).")
if shared.dataCipher == nil {
rp.logf("Not connected (handle interface).")
return
}
rp.conn.WriteTo(data, shared.remoteAddr)
h := header{
StreamID: dataStreamID,
Counter: atomic.AddUint64(&rp.counter, 1),
SourceIP: rp.localIP,
DestIP: rp.remoteIP,
}
enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf)
if shared.relayIP != 0 {
rp.peers[shared.relayIP].RelayData(shared.relayIP, enc)
} else {
rp.SendData(data)
}
}
// ----------------------------------------------------------------------------
func (rp *remotePeer) sendData(streamID byte, data []byte) {
shared := rp.shared.Load()
func (rp *remotePeer) RelayData(destIP byte, data []byte) {
rp.sendData(relayStreamID, destIP, data)
}
// ----------------------------------------------------------------------------
func (rp *remotePeer) sendData(streamID byte, destIP byte, data []byte) {
shared := rp.published.Load()
if shared.dataCipher == nil || shared.remoteAddr == zeroAddrPort {
rp.logf("Not connected (send).")
return
@ -222,16 +235,18 @@ func (rp *remotePeer) sendData(streamID byte, data []byte) {
StreamID: streamID,
Counter: atomic.AddUint64(&rp.counter, 1),
SourceIP: rp.localIP,
DestIP: rp.remoteIP,
DestIP: destIP,
}
enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf)
rp.conn.WriteTo(enc, shared.remoteAddr)
}
// ----------------------------------------------------------------------------
func (rp *remotePeer) canRelay() bool {
shared := rp.shared.Load()
return shared.relay && shared.up
func (rp *remotePeer) sendDirect(data []byte) {
shared := rp.published.Load()
if shared.remoteAddr == zeroAddrPort {
rp.logf("Not connected (send).")
return
}
rp.conn.WriteTo(data, shared.remoteAddr)
}