Client-server working. No mediator.

This commit is contained in:
jdl 2024-12-18 12:35:47 +01:00
parent d0d7bf9b58
commit 6e3c3ec0b2
18 changed files with 944 additions and 116 deletions

View File

@ -2,6 +2,10 @@
## Roadmap
* `node` package
* rename `peerRepo` to `routingTable`
* create router type with `Get(ip) *peer` and `Mediator() *peer` methods
* connReader / Writer should have access to the peerRepo
* Use default port 456
* Remove signing key from hub
* Peer: UDP hole-punching

View File

@ -2,10 +2,10 @@ package main
import (
"log"
"vppn/peer"
"vppn/node"
)
func main() {
log.SetFlags(0)
peer.Main()
node.Main()
}

View File

@ -4,30 +4,26 @@ import (
"log"
"net"
"net/netip"
"sync"
"sync/atomic"
"vppn/fasttime"
)
// TODO:
type connRouter interface {
Lookup(byte) *peer
Mediator() *peer
}
type connWriter struct {
*net.UDPConn
lock sync.Mutex
localIP byte
buf []byte
counters [256]uint64
lookup func(byte) *peer
routing *routingTable
}
func newConnWriter(conn *net.UDPConn, localIP byte, lookup func(byte) *peer) *connWriter {
func newConnWriter(conn *net.UDPConn, localIP byte, routing *routingTable) *connWriter {
w := &connWriter{
UDPConn: conn,
localIP: localIP,
buf: make([]byte, bufferSize),
lookup: lookup,
routing: routing,
}
for i := range w.counters {
@ -37,24 +33,36 @@ func newConnWriter(conn *net.UDPConn, localIP byte, lookup func(byte) *peer) *co
return w
}
func (w *connWriter) WriteTo(remoteIP, packetType byte, data []byte) error {
peer := w.lookup(remoteIP)
func (w *connWriter) WriteTo(remoteIP, stream byte, data []byte) error {
// TODO: Handle mediator.
peer := w.routing.Get(remoteIP)
if peer == nil || peer.Addr == nil {
log.Printf("No peer: %d", remoteIP)
return nil
}
if stream == streamData && !peer.Up {
log.Printf("Peer down: %d", remoteIP)
}
return w.WriteToPeer(peer, stream, data)
}
func (w *connWriter) WriteToPeer(peer *peer, stream byte, data []byte) error {
w.lock.Lock()
remoteIP := peer.IP
h := header{
Counter: atomic.AddUint64(&w.counters[remoteIP], 1),
SourceIP: w.localIP,
ViaIP: 0,
DestIP: remoteIP,
PacketType: packetType,
Counter: atomic.AddUint64(&w.counters[remoteIP], 1),
SourceIP: w.localIP,
ViaIP: 0,
DestIP: remoteIP,
Stream: stream,
}
buf := encryptPacket(&h, peer.SharedKey, data, w.buf)
_, err := w.WriteToUDPAddrPort(buf, *peer.Addr)
w.lock.Unlock()
return err
}
@ -64,15 +72,15 @@ type connReader struct {
*net.UDPConn
localIP byte
dupChecks [256]*dupCheck
lookup func(byte) *peer
routing *routingTable
buf []byte
}
func newConnReader(conn *net.UDPConn, localIP byte, lookup func(byte) *peer) *connReader {
func newConnReader(conn *net.UDPConn, localIP byte, routing *routingTable) *connReader {
r := &connReader{
UDPConn: conn,
localIP: localIP,
lookup: lookup,
routing: routing,
buf: make([]byte, bufferSize),
}
for i := range r.dupChecks {
@ -93,18 +101,20 @@ func (r *connReader) Read(buf []byte) (remoteAddr netip.AddrPort, h header, data
data = buf[:n]
if n < headerSize {
log.Printf("Dropping short packet: %d", n)
continue // Packet it soo short.
}
h.Parse(data)
if len(data) != headerSize+int(h.DataSize) {
log.Printf("Incorrect size")
continue // Packet is corrupt.
log.Printf("Malformed packet: %d != %d", len(data), headerSize+int(h.DataSize))
continue
}
peer := r.lookup(h.SourceIP)
peer := r.routing.Get(h.SourceIP)
if peer == nil {
log.Printf("No peer...")
log.Printf("No peer: %d...", h.SourceIP)
continue
}
@ -117,7 +127,7 @@ func (r *connReader) Read(buf []byte) (remoteAddr netip.AddrPort, h header, data
out, data = data, out
if r.dupChecks[h.SourceIP].IsDup(h.Counter) {
log.Printf("Duplicate...")
log.Printf("Duplicate: %d", h.Counter)
continue
}

View File

@ -1,6 +1,11 @@
package node
import "golang.org/x/crypto/nacl/box"
import (
"sync"
"vppn/fasttime"
"golang.org/x/crypto/nacl/box"
)
// Encrypting the packet will also set the header's DataSize field.
func encryptPacket(h *header, sharedKey, data, out []byte) []byte {
@ -24,3 +29,23 @@ func computeSharedKey(peerPubKey, privKey []byte) []byte {
box.Precompute(&shared, (*[32]byte)(peerPubKey), (*[32]byte)(privKey))
return shared[:]
}
var (
traceIDLock sync.Mutex
traceIDTime uint64
traceIDCounter uint64
)
func newTraceID() (id uint64) {
traceIDLock.Lock()
defer traceIDLock.Unlock()
now := uint64(fasttime.Now())
if traceIDTime < now {
traceIDTime = now
traceIDCounter = 0
}
traceIDCounter++
return traceIDTime<<30 + traceIDCounter
}

View File

@ -33,11 +33,11 @@ func TestEncryptDecryptPacket(t *testing.T) {
rand.Read(original)
h := header{
Counter: 2893749238,
SourceIP: 5,
ViaIP: 8,
DestIP: 12,
PacketType: 32,
Counter: 2893749238,
SourceIP: 5,
ViaIP: 8,
DestIP: 12,
Stream: 1,
}
encrypted := make([]byte, bufferSize)
@ -62,7 +62,6 @@ func TestEncryptDecryptPacket(t *testing.T) {
}
}
/*
func BenchmarkEncryptPacket(b *testing.B) {
_, privKey1, err := box.GenerateKey(rand.Reader)
if err != nil {
@ -77,16 +76,24 @@ func BenchmarkEncryptPacket(b *testing.B) {
sharedEncKey := [32]byte{}
box.Precompute(&sharedEncKey, pubKey2, privKey1)
original := make([]byte, MTU)
original := make([]byte, if_mtu)
rand.Read(original)
nonce := make([]byte, NONCE_SIZE)
nonce := make([]byte, headerSize)
rand.Read(nonce)
encrypted := make([]byte, BUFFER_SIZE)
encrypted := make([]byte, bufferSize)
h := header{
Counter: 2893749238,
SourceIP: 5,
ViaIP: 8,
DestIP: 12,
Stream: 1,
}
for i := 0; i < b.N; i++ {
encrypted = encryptPacket(sharedEncKey[:], nonce, original, encrypted)
encrypted = encryptPacket(&h, sharedEncKey[:], original, encrypted)
}
}
@ -107,18 +114,27 @@ func BenchmarkDecryptPacket(b *testing.B) {
sharedDecKey := [32]byte{}
box.Precompute(&sharedDecKey, pubKey1, privKey2)
original := make([]byte, MTU)
original := make([]byte, if_mtu)
rand.Read(original)
nonce := make([]byte, NONCE_SIZE)
nonce := make([]byte, headerSize)
rand.Read(nonce)
encrypted := make([]byte, BUFFER_SIZE)
encrypted = encryptPacket(sharedEncKey[:], nonce, original, encrypted)
h := header{
Counter: 2893749238,
SourceIP: 5,
ViaIP: 8,
DestIP: 12,
Stream: 1,
}
decrypted := make([]byte, MTU)
encrypted := encryptPacket(&h, sharedEncKey[:], original, make([]byte, bufferSize))
decrypted := make([]byte, bufferSize)
var ok bool
for i := 0; i < b.N; i++ {
decrypted, _ = decryptPacket(sharedDecKey[:], encrypted, decrypted)
decrypted, ok = decryptPacket(sharedDecKey[:], encrypted, decrypted)
if !ok {
panic(ok)
}
}
}
*/

View File

@ -1,5 +1,7 @@
package node
import "log"
type dupCheck struct {
bitSet
head int
@ -20,6 +22,7 @@ func (dc *dupCheck) IsDup(counter uint64) bool {
// Before head => it's late, say it's a dup.
if counter < dc.headCounter {
log.Printf("Late: %d", counter)
return true
}
@ -27,6 +30,7 @@ func (dc *dupCheck) IsDup(counter uint64) bool {
if counter < dc.tailCounter {
index := (int(counter-dc.headCounter) + dc.head) % bitSetSize
if dc.Get(index) {
log.Printf("Dup: %d, %d", counter, dc.tailCounter)
return true
}

82
node/files.go Normal file
View File

@ -0,0 +1,82 @@
package node
import (
"encoding/json"
"log"
"os"
"path/filepath"
"vppn/m"
)
func configDir(netName string) string {
d, err := os.UserHomeDir()
if err != nil {
log.Fatalf("Failed to get user home directory: %v", err)
}
return filepath.Join(d, ".vppn", netName)
}
func peerConfigPath(netName string) string {
return filepath.Join(configDir(netName), "peer-config.json")
}
func peerStatePath(netName string) string {
return filepath.Join(configDir(netName), "peer-state.json")
}
func storeJson(x any, outPath string) error {
outDir := filepath.Dir(outPath)
_ = os.MkdirAll(outDir, 0700)
tmpPath := outPath + ".tmp"
buf, err := json.Marshal(x)
if err != nil {
return err
}
f, err := os.Create(tmpPath)
if err != nil {
return err
}
if _, err := f.Write(buf); err != nil {
f.Close()
return err
}
if err := f.Sync(); err != nil {
f.Close()
return err
}
if err := f.Close(); err != nil {
return err
}
return os.Rename(tmpPath, outPath)
}
func storePeerConfig(netName string, pc m.PeerConfig) error {
return storeJson(pc, peerConfigPath(netName))
}
func storeNetworkState(netName string, ps m.NetworkState) error {
return storeJson(ps, peerStatePath(netName))
}
func loadJson(dataPath string, ptr any) error {
data, err := os.ReadFile(dataPath)
if err != nil {
return err
}
return json.Unmarshal(data, ptr)
}
func loadPeerConfig(netName string) (pc m.PeerConfig, err error) {
return pc, loadJson(peerConfigPath(netName), &pc)
}
func loadNetworkState(netName string) (ps m.NetworkState, err error) {
return ps, loadJson(peerStatePath(netName), &ps)
}

View File

@ -2,15 +2,19 @@ package node
import "unsafe"
const headerSize = 24
const (
headerSize = 24
streamData = 1
streamRouting = 2
)
type header struct {
Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic.
SourceIP byte
ViaIP byte
DestIP byte
PacketType byte // The packet type. See PACKET_* constants.
DataSize uint16 // Data size following associated data.
Counter uint64 // Init with fasttime.Now() << 30 to ensure monotonic.
SourceIP byte
ViaIP byte
DestIP byte
Stream byte // See stream* constants.
DataSize uint16 // Data size following associated data.
}
func (hdr *header) Parse(nb []byte) {
@ -18,7 +22,7 @@ func (hdr *header) Parse(nb []byte) {
hdr.SourceIP = nb[8]
hdr.ViaIP = nb[9]
hdr.DestIP = nb[10]
hdr.PacketType = nb[11]
hdr.Stream = nb[11]
hdr.DataSize = *(*uint16)(unsafe.Pointer(&nb[12]))
}
@ -27,6 +31,6 @@ func (hdr header) Marshal(buf []byte) {
buf[8] = hdr.SourceIP
buf[9] = hdr.ViaIP
buf[10] = hdr.DestIP
buf[11] = hdr.PacketType
buf[11] = hdr.Stream
*(*uint16)(unsafe.Pointer(&buf[12])) = hdr.DataSize
}

View File

@ -4,12 +4,12 @@ import "testing"
func TestHeaderMarshalParse(t *testing.T) {
nIn := header{
Counter: 3212,
SourceIP: 34,
ViaIP: 20,
DestIP: 200,
PacketType: 44,
DataSize: 1235,
Counter: 3212,
SourceIP: 34,
ViaIP: 20,
DestIP: 200,
Stream: 44,
DataSize: 1235,
}
buf := make([]byte, headerSize)

View File

@ -3,6 +3,7 @@ package node
import (
"fmt"
"io"
"log"
"net"
"os"
"syscall"
@ -22,23 +23,27 @@ func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error)
return nil, ip, err
}
if n < 20 {
continue // Packet too short.
}
buf = buf[:n]
version = buf[0] >> 4
switch version {
case 4:
if n < 20 {
log.Printf("Short IPv4 packet: %d", len(buf))
continue
}
ip = buf[19]
case 6:
if len(buf) < 40 {
continue // Packet too short.
log.Printf("Short IPv6 packet: %d", len(buf))
continue
}
ip = buf[39]
default:
continue // Invalid version.
log.Printf("Invalid IP packet version: %v", version)
continue
}
return buf, ip, nil
@ -47,7 +52,7 @@ func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error)
const (
if_mtu = 1200
if_queue_len = 1000
if_queue_len = 2048
)
func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) {

190
node/main.go Normal file
View File

@ -0,0 +1,190 @@
package node
import (
"encoding/json"
"flag"
"fmt"
"io"
"log"
"net"
"net/http"
"net/netip"
"os"
"runtime/debug"
"vppn/m"
)
func panicHandler() {
if r := recover(); r != nil {
log.Fatalf("\n %v\n\nstacktrace from panic: %s\n", r, string(debug.Stack()))
}
}
func Main() {
defer panicHandler()
var (
netName string
initURL string
listenIP string
port int
)
flag.StringVar(&netName, "name", "", "[REQUIRED] The network name.")
flag.StringVar(&initURL, "init-url", "", "Initializes peer from the hub URL.")
flag.StringVar(&listenIP, "listen-ip", "", "IP address to listen on.")
flag.IntVar(&port, "port", 0, "Port to listen on.")
flag.Parse()
if netName == "" {
flag.Usage()
os.Exit(1)
}
if initURL != "" {
mainInit(netName, initURL)
return
}
main(netName, listenIP, uint16(port))
}
func mainInit(netName, initURL string) {
if _, err := loadPeerConfig(netName); err == nil {
log.Fatalf("Network is already initialized.")
}
resp, err := http.Get(initURL)
if err != nil {
log.Fatalf("Failed to fetch data from hub: %v", err)
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
log.Fatalf("Failed to read response body: %v", err)
}
peerConfig := m.PeerConfig{}
if err := json.Unmarshal(data, &peerConfig); err != nil {
log.Fatalf("Failed to parse configuration: %v", err)
}
if err := storePeerConfig(netName, peerConfig); err != nil {
log.Fatalf("Failed to store configuration: %v", err)
}
log.Print("Initialization successful.")
}
// ----------------------------------------------------------------------------
func main(netName, listenIP string, port uint16) {
conf, err := loadPeerConfig(netName)
if err != nil {
log.Fatalf("Failed to load configuration: %v", err)
}
port = determinePort(conf.Port, port)
iface, err := openInterface(conf.Network, conf.PeerIP, netName)
if err != nil {
log.Fatalf("Failed to open interface: %v", err)
}
myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", listenIP, port))
if err != nil {
log.Fatalf("Failed to resolve UDP address: %v", err)
}
conn, err := net.ListenUDP("udp", myAddr)
if err != nil {
log.Fatalf("Failed to open UDP port: %v", err)
}
routing := newRoutingTable()
w := newConnWriter(conn, conf.PeerIP, routing)
r := newConnReader(conn, conf.PeerIP, routing)
router := newRouter(netName, conf, routing, w)
go nodeConnReader(r, w, iface, router)
nodeIFaceReader(w, iface, router)
}
// ----------------------------------------------------------------------------
func determinePort(confPort, portFromCommandLine uint16) uint16 {
if portFromCommandLine != 0 {
return portFromCommandLine
}
if confPort != 0 {
return confPort
}
return 456
}
// ----------------------------------------------------------------------------
func nodeConnReader(r *connReader, w *connWriter, iface io.ReadWriteCloser, router *router) {
defer panicHandler()
var (
remoteAddr netip.AddrPort
h header
buf = make([]byte, bufferSize)
data []byte
err error
)
for {
remoteAddr, h, data, err = r.Read(buf)
if err != nil {
log.Fatalf("Failed to read from UDP connection: %v", err)
}
switch h.Stream {
case streamData:
if _, err = iface.Write(data); err != nil {
log.Printf("Malformed data from peer %d: %v", h.SourceIP, err)
}
case streamRouting:
router.HandlePacket(h.SourceIP, remoteAddr, data)
default:
log.Printf("Dropping unknown stream: %d", h.Stream)
}
}
}
// ----------------------------------------------------------------------------
func nodeIFaceReader(w *connWriter, iface io.ReadWriteCloser, router *router) {
var (
buf = make([]byte, bufferSize)
packet []byte
remoteIP byte
err error
)
for {
packet, remoteIP, err = readNextPacket(iface, buf)
if err != nil {
log.Fatalf("Failed to read from interface: %v", err)
}
if remoteIP == w.localIP {
//log.Printf("Incoming packet for self: %x", packet)
//iface.Write(packet)
continue
}
if err := w.WriteTo(remoteIP, streamData, packet); err != nil {
log.Fatalf("Failed to write to network: %v", err)
}
}
}

1
node/node.go Normal file
View File

@ -0,0 +1 @@
package node

View File

@ -1,30 +1 @@
package node
import (
"net/netip"
"sync/atomic"
)
type peer struct {
IP byte
Addr *netip.AddrPort // If we have direct connection, otherwise use mediator.
SharedKey []byte
}
type peerRepo [256]*atomic.Pointer[peer]
func newPeerRepo() peerRepo {
pr := peerRepo{}
for i := range pr {
pr[i] = &atomic.Pointer[peer]{}
}
return pr
}
func (pr peerRepo) Get(ip byte) *peer {
return pr[ip].Load()
}
func (pr *peerRepo) Set(ip byte, p *peer) {
pr[ip].Store(p)
}

1
node/peerstate.go Normal file
View File

@ -0,0 +1 @@
package node

163
node/router.go Normal file
View File

@ -0,0 +1,163 @@
package node
import (
"encoding/json"
"io"
"log"
"net/http"
"net/netip"
"net/url"
"sync/atomic"
"time"
"vppn/m"
)
type peer struct {
Up bool // No data will be sent to peers that are down.
IP byte
Addr *netip.AddrPort // If we have direct connection, otherwise use mediator.
SharedKey []byte
}
// ----------------------------------------------------------------------------
type routingTable struct {
table [256]*atomic.Pointer[peer]
mediator *atomic.Pointer[peer]
}
func newRoutingTable() *routingTable {
r := routingTable{
mediator: &atomic.Pointer[peer]{},
}
for i := range r.table {
r.table[i] = &atomic.Pointer[peer]{}
}
return &r
}
func (r *routingTable) Get(ip byte) *peer {
return r.table[ip].Load()
}
func (r *routingTable) Set(ip byte, p *peer) {
r.table[ip].Store(p)
}
// ----------------------------------------------------------------------------
type router struct {
netName string
*routingTable
peerSupers [256]*peerSupervisor
}
func newRouter(netName string, conf m.PeerConfig, routingData *routingTable, w *connWriter) *router {
r := &router{
netName: netName,
routingTable: routingData,
}
for i := range r.peerSupers {
r.peerSupers[i] = newPeerSupervisor(
conf,
byte(i),
w,
r.routingTable)
}
// TODO: Handle Mediator
go r.pollHub(conf)
return r
}
// ----------------------------------------------------------------------------
func (r *router) HandlePacket(sourceIP byte, remoteAddr netip.AddrPort, data []byte) {
p := routingPacket{}
if err := p.Parse(data); err != nil {
log.Printf("Dropping malformed routing packet: %v", err)
return
}
w := routingPacketWrapper{
routingPacket: p,
Addr: remoteAddr,
}
r.peerSupers[sourceIP].HandlePacket(w)
}
// ----------------------------------------------------------------------------
func (r *router) pollHub(conf m.PeerConfig) {
defer panicHandler()
u, err := url.Parse(conf.HubAddress)
if err != nil {
log.Fatalf("Failed to parse hub address %s: %v", conf.HubAddress, err)
}
u.Path = "/peer/fetch-state/"
client := &http.Client{Timeout: 8 * time.Second}
req := &http.Request{
Method: http.MethodGet,
URL: u,
Header: http.Header{},
}
req.SetBasicAuth("", conf.APIKey)
// TODO: Before we start polling, load state from the file system.
state, err := loadNetworkState(r.netName)
if err != nil {
log.Printf("Failed to load network state: %v", err)
log.Printf("Polling hub...")
r._pollHub(conf, client, req)
} else {
r.applyNetworkState(conf, state)
}
for range time.Tick(64 * time.Second) {
r._pollHub(conf, client, req)
}
}
func (r *router) _pollHub(conf m.PeerConfig, client *http.Client, req *http.Request) {
var state m.NetworkState
log.Printf("Fetching peer state from %s...", conf.HubAddress)
resp, err := client.Do(req)
if err != nil {
log.Printf("Failed to fetch peer state: %v", err)
return
}
body, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
log.Printf("Failed to read body from hub: %v", err)
return
}
if err := json.Unmarshal(body, &state); err != nil {
log.Printf("Failed to unmarshal response from hub: %v", err)
return
}
r.applyNetworkState(conf, state)
if err := storeNetworkState(r.netName, state); err != nil {
log.Printf("Failed to store network state: %v", err)
}
}
func (r *router) applyNetworkState(conf m.PeerConfig, state m.NetworkState) {
for i := range state.Peers {
if i != int(conf.PeerIP) {
r.peerSupers[i].HandlePeerUpdate(state.Peers[i])
}
}
}

44
node/routingpacket.go Normal file
View File

@ -0,0 +1,44 @@
package node
import (
"errors"
"unsafe"
)
var errMalformedPacket = errors.New("malformed packet")
const (
packetTypeInvalid = iota
// Used to maintain connection.
packetTypePing
packetTypePong
)
type routingPacket struct {
Type byte // One of the packetType* constants.
TraceID uint64 // For matching requests and responses.
}
func newRoutingPacket(reqType byte, traceID uint64) routingPacket {
return routingPacket{
Type: reqType,
TraceID: traceID,
}
}
func (p routingPacket) Marshal(buf []byte) []byte {
buf = buf[:32] // Reserve 32 bytes just in case we need to add anything.
buf[0] = p.Type
*(*uint64)(unsafe.Pointer(&buf[1])) = uint64(p.TraceID)
return buf
}
func (p *routingPacket) Parse(buf []byte) error {
if len(buf) != 32 {
return errMalformedPacket
}
p.Type = buf[0]
p.TraceID = *(*uint64)(unsafe.Pointer(&buf[1]))
return nil
}

View File

@ -1,14 +1,6 @@
package node
import (
"fmt"
"io"
"log"
"net"
"net/netip"
"runtime/debug"
)
/*
var (
network = []byte{10, 1, 1, 0}
serverIP = byte(1)
@ -30,7 +22,7 @@ func must(err error) {
type TmpNode struct {
network []byte
localIP byte
peers peerRepo
router *router
port uint16
netName string
iface io.ReadWriteCloser
@ -46,7 +38,7 @@ func NewTmpNodeServer() *TmpNode {
n := &TmpNode{
localIP: serverIP,
network: network,
peers: newPeerRepo(),
router: &router{table: newPeerRepo()},
port: port,
netName: netName,
pubKey: pubKey1,
@ -63,10 +55,10 @@ func NewTmpNodeServer() *TmpNode {
conn, err := net.ListenUDP("udp", myAddr)
must(err)
n.w = newConnWriter(conn, n.localIP, n.peers.Get)
n.r = newConnReader(conn, n.localIP, n.peers.Get)
n.w = newConnWriter(conn, n.localIP, n.router)
n.r = newConnReader(conn, n.localIP, n.router)
n.peers.Set(clientIP, &peer{
n.router.table.Set(clientIP, &peer{
IP: clientIP,
SharedKey: computeSharedKey(pubKey2, n.privKey),
})
@ -80,7 +72,7 @@ func NewTmpNodeClient(srvAddrStr string) *TmpNode {
n := &TmpNode{
localIP: clientIP,
network: network,
peers: newPeerRepo(),
router: &router{table: newPeerRepo()},
port: port,
netName: netName,
pubKey: pubKey2,
@ -97,13 +89,13 @@ func NewTmpNodeClient(srvAddrStr string) *TmpNode {
conn, err := net.ListenUDP("udp", myAddr)
must(err)
n.w = newConnWriter(conn, n.localIP, n.peers.Get)
n.r = newConnReader(conn, n.localIP, n.peers.Get)
n.w = newConnWriter(conn, n.localIP, n.router)
n.r = newConnReader(conn, n.localIP, n.router)
serverAddr, err := netip.ParseAddrPort(fmt.Sprintf("%s:%d", srvAddrStr, port))
must(err)
n.peers.Set(serverIP, &peer{
n.router.table.Set(serverIP, &peer{
IP: serverIP,
Addr: &serverAddr,
SharedKey: computeSharedKey(pubKey1, n.privKey),
@ -129,7 +121,7 @@ func (n *TmpNode) RunServer() {
log.Printf("Got remote addr: %d -> %v", h.SourceIP, remoteAddr)
must(err)
n.peers.Set(h.SourceIP, &peer{
n.router.table.Set(h.SourceIP, &peer{
IP: h.SourceIP,
Addr: &remoteAddr,
SharedKey: computeSharedKey(pubKey2, n.privKey),
@ -144,7 +136,7 @@ func (n *TmpNode) RunServer() {
func (n *TmpNode) RunClient() {
defer func() {
if r := recover(); r != nil {
fmt.Printf("%v", r)
fmt.Printf("%v\n", r)
debug.PrintStack()
}
}()
@ -184,6 +176,10 @@ func (node *TmpNode) readFromConn() {
// We assume that we're only receiving packets from one source.
_, err = node.iface.Write(packet)
must(err)
if err != nil {
log.Printf("Got error: %v", err)
}
//must(err)
}
}
*/

312
node/tmp_peerstate.go Normal file
View File

@ -0,0 +1,312 @@
package node
import (
"fmt"
"log"
"net/netip"
"time"
"vppn/m"
)
const (
connectTimeout = 6 * time.Second
pingInterval = 6 * time.Second
timeoutInterval = 20 * time.Second
)
type routingPacketWrapper struct {
routingPacket
Addr netip.AddrPort // Source.
}
type peerSupervisor struct {
// Constants:
localIP byte
localPublic bool
remoteIP byte
privKey []byte
// Shared data:
w *connWriter
table *routingTable
packets chan routingPacketWrapper
peerUpdates chan *m.Peer
// Peer-related items.
version int64 // Ony accessed in HandlePeerUpdate.
peer *m.Peer
remoteAddrPort *netip.AddrPort
sharedKey []byte
// Used by our state functions.
pingTimer *time.Timer
timeoutTimer *time.Timer
buf []byte
}
// ----------------------------------------------------------------------------
func newPeerSupervisor(
conf m.PeerConfig,
remoteIP byte,
w *connWriter,
table *routingTable,
) *peerSupervisor {
s := &peerSupervisor{
localIP: conf.PeerIP,
remoteIP: remoteIP,
privKey: conf.EncPrivKey,
w: w,
table: table,
packets: make(chan routingPacketWrapper, 256),
peerUpdates: make(chan *m.Peer, 1),
pingTimer: time.NewTimer(pingInterval),
timeoutTimer: time.NewTimer(timeoutInterval),
buf: make([]byte, bufferSize),
}
_, s.localPublic = netip.AddrFromSlice(conf.PublicIP)
go s.mainLoop()
return s
}
func (s *peerSupervisor) logf(msg string, args ...any) {
msg = fmt.Sprintf("[%03d] ", s.remoteIP) + msg
log.Printf(msg, args...)
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) mainLoop() {
defer panicHandler()
state := s.stateInit
for {
state = state()
}
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) HandlePeerUpdate(p *m.Peer) {
if p != nil {
if p.Version == s.version {
return
}
s.logf("New peer version: %d", p.Version)
s.version = p.Version
} else {
s.version = 0
}
s.peerUpdates <- p
}
func (s *peerSupervisor) HandlePacket(w routingPacketWrapper) {
select {
case s.packets <- w:
default:
// Drop
}
}
// ----------------------------------------------------------------------------
type stateFunc func() stateFunc
func (s *peerSupervisor) stateInit() stateFunc {
if s.peer == nil {
return s.stateDisconnected
}
addr, ok := netip.AddrFromSlice(s.peer.PublicIP)
if ok {
addrPort := netip.AddrPortFrom(addr, s.peer.Port)
s.remoteAddrPort = &addrPort
} else {
s.remoteAddrPort = nil
}
s.sharedKey = computeSharedKey(s.peer.EncPubKey, s.privKey)
return s.stateSelectRole()
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) stateDisconnected() stateFunc {
s.clearRoutingTable()
for {
select {
case <-s.packets:
// Drop
case s.peer = <-s.peerUpdates:
return s.stateInit
}
}
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) stateSelectRole() stateFunc {
s.logf("STATE: SelectRole")
s.updateRoutingTable(false)
if s.remoteAddrPort != nil {
// If both remote and local are public, one side acts as client, and one
// side as server.
if s.localPublic && s.localIP < s.peer.PeerIP {
return s.stateAccept
}
return s.stateDial
}
// We're public, remote is not => can only wait for connection
if s.localPublic {
return s.stateAccept
}
// Both non-public => need to use mediator.
return s.stateMediated
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) stateAccept() stateFunc {
s.logf("STATE: Accept")
for {
select {
case pkt := <-s.packets:
switch pkt.Type {
case packetTypePing:
s.remoteAddrPort = &pkt.Addr
s.updateRoutingTable(true)
s.sendPong(pkt.TraceID)
return s.stateConnected
default:
// Still waiting for ping...
}
case s.peer = <-s.peerUpdates:
return s.stateInit
}
}
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) stateDial() stateFunc {
s.logf("STATE: Dial")
s.updateRoutingTable(false)
s.sendPing()
for {
select {
case pkt := <-s.packets:
switch pkt.Type {
case packetTypePong:
s.updateRoutingTable(true)
return s.stateConnected
default:
// Ignore
}
case <-s.pingTimer.C:
s.sendPing()
case s.peer = <-s.peerUpdates:
return s.stateInit
}
}
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) stateConnected() stateFunc {
s.logf("STATE: Connected")
s.timeoutTimer.Reset(timeoutInterval)
for {
select {
case <-s.pingTimer.C:
s.sendPing()
case <-s.timeoutTimer.C:
s.logf("Timeout")
return s.stateInit
case pkt := <-s.packets:
switch pkt.Type {
case packetTypePing:
s.sendPong(pkt.TraceID)
// Server should always follow remote port.
if s.localPublic {
if pkt.Addr != *s.remoteAddrPort {
s.remoteAddrPort = &pkt.Addr
s.updateRoutingTable(true)
}
}
case packetTypePong:
s.timeoutTimer.Reset(timeoutInterval)
default:
// Drop packet.
}
case s.peer = <-s.peerUpdates:
s.logf("New peer: %v", s.peer)
return s.stateInit
}
}
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) stateMediated() stateFunc {
s.logf("STATE: Mediated")
// TODO
select {}
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) clearRoutingTable() {
s.table.Set(s.remoteIP, nil)
}
func (s *peerSupervisor) updateRoutingTable(up bool) {
s.table.Set(s.remoteIP, &peer{
Up: up,
IP: s.remoteIP,
Addr: s.remoteAddrPort,
SharedKey: s.sharedKey,
})
}
// ----------------------------------------------------------------------------
func (s *peerSupervisor) sendPing() uint64 {
traceID := newTraceID()
pkt := newRoutingPacket(packetTypePing, traceID)
s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf))
s.pingTimer.Reset(pingInterval)
return traceID
}
func (s *peerSupervisor) sendPong(traceID uint64) {
pkt := newRoutingPacket(packetTypePong, traceID)
s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf))
}