wip: client/server working.
This commit is contained in:
parent
5b34b3311b
commit
c7d3fe1ed8
@ -5,6 +5,7 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@ -22,6 +23,7 @@ func newConnWriter(conn *net.UDPConn) *connWriter {
|
||||
func (w *connWriter) WriteTo(packet []byte, addr netip.AddrPort) {
|
||||
w.lock.Lock()
|
||||
if _, err := w.conn.WriteToUDPAddrPort(packet, addr); err != nil {
|
||||
debug.PrintStack()
|
||||
log.Fatalf("Failed to write to UDP port: %v", err)
|
||||
}
|
||||
w.lock.Unlock()
|
||||
|
@ -10,7 +10,7 @@ const (
|
||||
controlHeaderSize = 24
|
||||
dataStreamID = 1
|
||||
dataHeaderSize = 12
|
||||
forwardStreamID = 3
|
||||
relayStreamID = 3
|
||||
)
|
||||
|
||||
type header struct {
|
||||
|
10
node/main.go
10
node/main.go
@ -114,7 +114,6 @@ func main(netName, listenIP string, port uint16) {
|
||||
go newHubPoller(netName, conf, peers).Run()
|
||||
go readFromConn(conf.PeerIP, conn, peers)
|
||||
readFromIFace(iface, peers)
|
||||
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
@ -157,12 +156,7 @@ func readFromConn(localIP byte, conn *net.UDPConn, peers remotePeers) {
|
||||
}
|
||||
|
||||
h.Parse(data)
|
||||
|
||||
if h.DestIP == localIP {
|
||||
peers[h.SourceIP].HandlePacket(remoteAddr, h, data)
|
||||
} else {
|
||||
peers[h.DestIP].ForwardPacket(data)
|
||||
}
|
||||
peers[h.SourceIP].HandlePacket(remoteAddr, h, data)
|
||||
}
|
||||
}
|
||||
|
||||
@ -183,6 +177,6 @@ func readFromIFace(iface io.ReadWriteCloser, peers remotePeers) {
|
||||
log.Fatalf("Failed to read from interface: %v", err)
|
||||
}
|
||||
|
||||
peers[remoteIP].SendData(packet)
|
||||
peers[remoteIP].HandleInterfacePacket(packet)
|
||||
}
|
||||
}
|
||||
|
@ -7,7 +7,10 @@ import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var errMalformedPacket = errors.New("malformed packet")
|
||||
var (
|
||||
errMalformedPacket = errors.New("malformed packet")
|
||||
errUnknownPacketType = errors.New("unknown packet type")
|
||||
)
|
||||
|
||||
const (
|
||||
packetTypePing = iota + 1
|
||||
@ -22,6 +25,18 @@ type controlPacket struct {
|
||||
Payload any
|
||||
}
|
||||
|
||||
func (p *controlPacket) ParsePayload(buf []byte) (err error) {
|
||||
switch buf[0] {
|
||||
case packetTypePing:
|
||||
p.Payload, err = parsePingPacket(buf)
|
||||
case packetTypePong:
|
||||
p.Payload, err = parsePongPacket(buf)
|
||||
default:
|
||||
return errUnknownPacketType
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// A pingPacket is sent from a node acting as a client, to a node acting
|
||||
@ -32,9 +47,9 @@ type pingPacket struct {
|
||||
SharedKey [32]byte
|
||||
}
|
||||
|
||||
func newPingPacket(sharedKey []byte) (pp pingPacket) {
|
||||
func newPingPacket(sharedKey [32]byte) (pp pingPacket) {
|
||||
pp.SentAt = time.Now().UnixMilli()
|
||||
copy(pp.SharedKey[:], sharedKey)
|
||||
copy(pp.SharedKey[:], sharedKey[:])
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -12,7 +12,7 @@ func TestPacketPing(t *testing.T) {
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
|
||||
p := newPingPacket(sharedKey)
|
||||
p := newPingPacket([32]byte(sharedKey))
|
||||
out := p.Marshal(buf)
|
||||
|
||||
p2, err := parsePingPacket(out)
|
||||
|
214
node/peer-states.go
Normal file
214
node/peer-states.go
Normal file
@ -0,0 +1,214 @@
|
||||
package node
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"vppn/m"
|
||||
)
|
||||
|
||||
type peerState interface {
|
||||
Name() string
|
||||
OnPeerUpdate(*m.Peer) peerState
|
||||
OnPing(netip.AddrPort, pingPacket) peerState
|
||||
OnPong(netip.AddrPort, pongPacket) peerState
|
||||
OnPingTimer() peerState
|
||||
OnTimeoutTimer() peerState
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type stateBase struct {
|
||||
// The purpose of this state machine is to manage this published data.
|
||||
published *atomic.Pointer[peerData]
|
||||
|
||||
// The other remote peers.
|
||||
peers *remotePeers
|
||||
|
||||
// Immutable data.
|
||||
localIP byte
|
||||
localPub bool
|
||||
remoteIP byte
|
||||
privKey []byte
|
||||
conn *connWriter
|
||||
|
||||
// For sending to peer.
|
||||
counter *uint64
|
||||
|
||||
// Mutable peer data.
|
||||
peer *m.Peer
|
||||
remotePub bool
|
||||
data peerData // Local copy of shared data. See publish().
|
||||
|
||||
// Timers
|
||||
pingTimer *time.Timer
|
||||
timeoutTimer *time.Timer
|
||||
|
||||
buf []byte
|
||||
encBuf []byte
|
||||
}
|
||||
|
||||
func (sb *stateBase) Name() string { return "idle" }
|
||||
|
||||
func (s *stateBase) OnPeerUpdate(peer *m.Peer) peerState {
|
||||
// Both nil: no change.
|
||||
if peer == nil && s.peer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// No change.
|
||||
if peer != nil && s.peer != nil && s.peer.Version == peer.Version {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.peer = peer
|
||||
|
||||
s.data = peerData{}
|
||||
s.data.controlCipher = newControlCipher(s.privKey, peer.EncPubKey)
|
||||
|
||||
ip, isValid := netip.AddrFromSlice(peer.PublicIP)
|
||||
if isValid {
|
||||
s.remotePub = true
|
||||
s.data.remoteAddr = netip.AddrPortFrom(ip, peer.Port)
|
||||
s.data.relay = peer.Mediator
|
||||
|
||||
if s.localPub && s.localIP < s.remoteIP {
|
||||
return newStateServer(s)
|
||||
}
|
||||
return newStateClient(s)
|
||||
}
|
||||
|
||||
if s.localPub {
|
||||
return newStateServer(s)
|
||||
}
|
||||
|
||||
// TODO: return newStateMediated(a/b)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stateBase) OnPing(rAddr netip.AddrPort, p pingPacket) peerState { return nil }
|
||||
func (s *stateBase) OnPong(rAddr netip.AddrPort, p pongPacket) peerState { return nil }
|
||||
func (s *stateBase) OnPingTimer() peerState { return nil }
|
||||
func (s *stateBase) OnTimeoutTimer() peerState { return nil }
|
||||
|
||||
// Helpers.
|
||||
|
||||
func (s *stateBase) resetPingTimer() { s.pingTimer.Reset(pingInterval) }
|
||||
func (s *stateBase) resetTimeoutTimer() { s.timeoutTimer.Reset(timeoutInterval) }
|
||||
func (s *stateBase) stopPingTimer() { s.pingTimer.Stop() }
|
||||
func (s *stateBase) stopTimeoutTimer() { s.timeoutTimer.Stop() }
|
||||
|
||||
func (s *stateBase) logf(msg string, args ...any) {
|
||||
log.Printf(fmt.Sprintf("[%03d] ", s.remoteIP)+msg, args...)
|
||||
}
|
||||
|
||||
func (s *stateBase) publish() {
|
||||
data := s.data
|
||||
s.published.Store(&data)
|
||||
}
|
||||
|
||||
func (s *stateBase) sendPing(sharedKey [32]byte) {
|
||||
s.sendControlPacket(newPingPacket(sharedKey))
|
||||
}
|
||||
|
||||
func (s *stateBase) sendPong(ping pingPacket) {
|
||||
s.sendControlPacket(newPongPacket(ping.SentAt))
|
||||
}
|
||||
|
||||
func (s *stateBase) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
|
||||
buf := pkt.Marshal(s.buf)
|
||||
h := header{
|
||||
StreamID: controlStreamID,
|
||||
Counter: atomic.AddUint64(s.counter, 1),
|
||||
SourceIP: s.localIP,
|
||||
DestIP: s.remoteIP,
|
||||
}
|
||||
|
||||
buf = s.data.controlCipher.Encrypt(h, buf, s.encBuf)
|
||||
if s.data.relayIP == 0 {
|
||||
s.conn.WriteTo(buf, s.data.remoteAddr)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Relay!
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type stateClient struct {
|
||||
sharedKey [32]byte
|
||||
*stateBase
|
||||
}
|
||||
|
||||
func newStateClient(b *stateBase) peerState {
|
||||
s := &stateClient{stateBase: b}
|
||||
s.publish()
|
||||
|
||||
s.data.dataCipher = newDataCipher()
|
||||
s.sharedKey = s.data.dataCipher.Key()
|
||||
|
||||
s.sendPing(s.sharedKey)
|
||||
s.resetPingTimer()
|
||||
s.resetTimeoutTimer()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *stateClient) Name() string { return "client" }
|
||||
|
||||
func (s *stateClient) OnPong(addr netip.AddrPort, p pongPacket) peerState {
|
||||
if !s.data.up {
|
||||
s.data.up = true
|
||||
s.publish()
|
||||
}
|
||||
s.resetTimeoutTimer()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stateClient) OnPingTimer() peerState {
|
||||
s.sendPing(s.sharedKey)
|
||||
s.resetPingTimer()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stateClient) OnTimeoutTimer() peerState {
|
||||
s.data.up = false
|
||||
s.publish()
|
||||
return nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type stateServer struct {
|
||||
*stateBase
|
||||
}
|
||||
|
||||
func newStateServer(b *stateBase) peerState {
|
||||
s := &stateServer{b}
|
||||
s.publish()
|
||||
s.stopPingTimer()
|
||||
s.stopTimeoutTimer()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *stateServer) Name() string { return "server" }
|
||||
|
||||
func (s *stateServer) OnPing(addr netip.AddrPort, p pingPacket) peerState {
|
||||
if addr != s.data.remoteAddr {
|
||||
s.logf("Got new peer address: %v", addr)
|
||||
s.data.remoteAddr = addr
|
||||
s.data.up = true
|
||||
s.publish()
|
||||
}
|
||||
|
||||
if s.data.dataCipher == nil || p.SharedKey != s.data.dataCipher.Key() {
|
||||
s.logf("Got new shared key.")
|
||||
s.data.dataCipher = newDataCipherFromKey(p.SharedKey)
|
||||
s.publish()
|
||||
}
|
||||
|
||||
s.sendPong(p)
|
||||
return nil
|
||||
}
|
@ -1,10 +1,6 @@
|
||||
package node
|
||||
|
||||
import (
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"vppn/m"
|
||||
)
|
||||
@ -15,263 +11,64 @@ const (
|
||||
timeoutInterval = 20 * time.Second
|
||||
)
|
||||
|
||||
type stateFunc func() stateFunc
|
||||
|
||||
type peerSuper struct {
|
||||
*remotePeer
|
||||
|
||||
peer *m.Peer
|
||||
remotePublic bool
|
||||
peerData peerData
|
||||
|
||||
pktBuf []byte
|
||||
encBuf []byte
|
||||
}
|
||||
|
||||
func newPeerSuper(rp *remotePeer) *peerSuper {
|
||||
return &peerSuper{
|
||||
remotePeer: rp,
|
||||
peer: nil,
|
||||
pktBuf: make([]byte, bufferSize),
|
||||
encBuf: make([]byte, bufferSize),
|
||||
}
|
||||
}
|
||||
|
||||
func (rp *peerSuper) Run() {
|
||||
func (rp *remotePeer) supervise(
|
||||
conf m.PeerConfig,
|
||||
remoteIP byte,
|
||||
conn *connWriter,
|
||||
peers *remotePeers,
|
||||
) {
|
||||
defer panicHandler()
|
||||
state := rp.stateInit
|
||||
for {
|
||||
state = state()
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (rp *peerSuper) stateInit() stateFunc {
|
||||
//rp.logf("STATE: Init")
|
||||
|
||||
x := peerData{}
|
||||
rp.shared.Store(&x)
|
||||
|
||||
rp.peerData.relay = false
|
||||
rp.peerData.controlCipher = nil
|
||||
rp.peerData.dataCipher = nil
|
||||
rp.peerData.remoteAddr = zeroAddrPort
|
||||
rp.peerData.relayIP = 0
|
||||
|
||||
if rp.peer == nil {
|
||||
return rp.stateDisconnected
|
||||
base := &stateBase{
|
||||
published: rp.published,
|
||||
peers: rp.peers,
|
||||
localIP: rp.localIP,
|
||||
remoteIP: rp.remoteIP,
|
||||
privKey: conf.EncPrivKey,
|
||||
localPub: addrIsValid(conf.PublicIP),
|
||||
conn: rp.conn,
|
||||
counter: &rp.counter,
|
||||
pingTimer: time.NewTimer(time.Second),
|
||||
timeoutTimer: time.NewTimer(time.Second),
|
||||
buf: make([]byte, bufferSize),
|
||||
encBuf: make([]byte, bufferSize),
|
||||
}
|
||||
|
||||
var addr netip.Addr
|
||||
addr, rp.remotePublic = netip.AddrFromSlice(rp.peer.PublicIP)
|
||||
if rp.remotePublic {
|
||||
rp.peerData.remoteAddr = netip.AddrPortFrom(addr, rp.peer.Port)
|
||||
} else {
|
||||
rp.peerData.relay = false
|
||||
}
|
||||
|
||||
rp.peerData.controlCipher = newControlCipher(rp.privKey, rp.peer.EncPubKey)
|
||||
|
||||
return rp.stateSelectRole()
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (rp *peerSuper) stateDisconnected() stateFunc {
|
||||
//rp.logf("STATE: Disconnected")
|
||||
for {
|
||||
select {
|
||||
case <-rp.controlPackets:
|
||||
// Drop
|
||||
case rp.peer = <-rp.peerUpdates:
|
||||
return rp.stateInit
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (rp *peerSuper) stateSelectRole() stateFunc {
|
||||
rp.logf("STATE: SelectRole")
|
||||
|
||||
if !rp.localPublic && !rp.remotePublic {
|
||||
return rp.stateSelectMediator
|
||||
}
|
||||
|
||||
if !rp.localPublic {
|
||||
return rp.stateServer
|
||||
} else if !rp.remotePublic {
|
||||
return rp.stateClient
|
||||
}
|
||||
|
||||
if rp.localIP < rp.remoteIP {
|
||||
return rp.stateClient
|
||||
}
|
||||
return rp.stateServer
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (rp *peerSuper) stateSelectMediator() stateFunc {
|
||||
rp.logf("STATE: SelectMediator")
|
||||
|
||||
for {
|
||||
log.Printf("Selecting mediator...")
|
||||
if ip := rp.selectMediator(); ip != 0 {
|
||||
rp.logf("Got mediator: %d", ip)
|
||||
rp.peerData.relayIP = ip
|
||||
|
||||
if rp.localIP < rp.remoteIP {
|
||||
return rp.stateClient
|
||||
}
|
||||
return rp.stateServer
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(pingInterval):
|
||||
continue
|
||||
case rp.peer = <-rp.peerUpdates:
|
||||
return rp.stateInit
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (rp *peerSuper) selectMediator() byte {
|
||||
possible := make([]byte, 0, 8)
|
||||
for _, peer := range rp.peers {
|
||||
if peer.canRelay() {
|
||||
rp.logf("relay: %v", peer.shared.Load())
|
||||
possible = append(possible, peer.remoteIP)
|
||||
}
|
||||
}
|
||||
if len(possible) == 0 {
|
||||
return 0
|
||||
}
|
||||
return possible[rand.Intn(len(possible))]
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// The remote is a server.
|
||||
func (rp *peerSuper) stateServer() stateFunc {
|
||||
rp.logf("STATE: Server")
|
||||
rp.peerData.dataCipher = newDataCipher()
|
||||
rp.updateShared()
|
||||
base.pingTimer.Stop()
|
||||
base.timeoutTimer.Stop()
|
||||
|
||||
var (
|
||||
pingTimer = time.NewTimer(pingInterval)
|
||||
timeoutTimer = time.NewTimer(timeoutInterval)
|
||||
ping = pingPacket{SharedKey: ([32]byte)(rp.peerData.dataCipher.Key())}
|
||||
curState peerState = base
|
||||
nextState peerState
|
||||
)
|
||||
defer pingTimer.Stop()
|
||||
defer timeoutTimer.Stop()
|
||||
|
||||
ping.SentAt = time.Now().UnixMilli()
|
||||
rp.sendControlPacket(ping)
|
||||
|
||||
for {
|
||||
nextState = nil
|
||||
|
||||
select {
|
||||
case <-pingTimer.C:
|
||||
ping.SentAt = time.Now().UnixMilli()
|
||||
rp.sendControlPacket(ping)
|
||||
pingTimer.Reset(pingInterval)
|
||||
case peer := <-rp.peerUpdates:
|
||||
nextState = curState.OnPeerUpdate(peer)
|
||||
|
||||
case cPkt := <-rp.controlPackets:
|
||||
if _, ok := cPkt.Payload.(pongPacket); ok {
|
||||
timeoutTimer.Reset(timeoutInterval)
|
||||
case pkt := <-rp.controlPackets:
|
||||
switch p := pkt.Payload.(type) {
|
||||
case pingPacket:
|
||||
nextState = curState.OnPing(pkt.RemoteAddr, p)
|
||||
case pongPacket:
|
||||
nextState = curState.OnPong(pkt.RemoteAddr, p)
|
||||
default:
|
||||
// Unknown packet type.
|
||||
}
|
||||
|
||||
case <-timeoutTimer.C:
|
||||
if rp.peerData.relayIP != 0 {
|
||||
rp.logf("Timeout (server, relay)")
|
||||
return rp.stateSelectMediator
|
||||
} else {
|
||||
rp.logf("Timeout (server)")
|
||||
}
|
||||
case <-base.pingTimer.C:
|
||||
nextState = curState.OnPingTimer()
|
||||
|
||||
case rp.peer = <-rp.peerUpdates:
|
||||
return rp.stateInit
|
||||
case <-base.timeoutTimer.C:
|
||||
nextState = curState.OnTimeoutTimer()
|
||||
}
|
||||
|
||||
if nextState != nil {
|
||||
rp.logf("%s --> %s", curState.Name(), nextState.Name())
|
||||
curState = nextState
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// The remote is a client.
|
||||
func (rp *peerSuper) stateClient() stateFunc {
|
||||
rp.logf("STATE: Client")
|
||||
rp.updateShared()
|
||||
|
||||
var (
|
||||
currentKey = [32]byte{}
|
||||
timeoutTimer = time.NewTimer(timeoutInterval)
|
||||
)
|
||||
|
||||
defer timeoutTimer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case cPkt := <-rp.controlPackets:
|
||||
if cPkt.RemoteAddr != rp.peerData.remoteAddr {
|
||||
rp.peerData.remoteAddr = cPkt.RemoteAddr
|
||||
rp.logf("Got new remote address: %v", cPkt.RemoteAddr)
|
||||
rp.updateShared()
|
||||
}
|
||||
|
||||
ping, ok := cPkt.Payload.(pingPacket)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if ping.SharedKey != currentKey {
|
||||
rp.logf("Connected with new shared key")
|
||||
currentKey = ping.SharedKey
|
||||
rp.peerData.up = true
|
||||
rp.peerData.dataCipher = newDataCipherFromKey(currentKey)
|
||||
rp.updateShared()
|
||||
}
|
||||
|
||||
timeoutTimer.Reset(timeoutInterval)
|
||||
rp.sendControlPacket(newPongPacket(ping.SentAt))
|
||||
|
||||
case <-timeoutTimer.C:
|
||||
if rp.peerData.relayIP != 0 {
|
||||
rp.logf("Timeout (server, relay)")
|
||||
return rp.stateSelectMediator
|
||||
} else {
|
||||
rp.logf("Timeout (server)")
|
||||
}
|
||||
|
||||
case rp.peer = <-rp.peerUpdates:
|
||||
return rp.stateInit
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (rp *peerSuper) updateShared() {
|
||||
data := rp.peerData
|
||||
rp.shared.Store(&data)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (rp *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) {
|
||||
buf := pkt.Marshal(rp.pktBuf)
|
||||
h := header{
|
||||
StreamID: controlStreamID,
|
||||
Counter: atomic.AddUint64(&rp.counter, 1),
|
||||
SourceIP: rp.localIP,
|
||||
DestIP: rp.remoteIP,
|
||||
}
|
||||
buf = rp.peerData.controlCipher.Encrypt(h, buf, rp.encBuf)
|
||||
if rp.peerData.relayIP == 0 {
|
||||
rp.conn.WriteTo(buf, rp.peerData.remoteAddr)
|
||||
return
|
||||
}
|
||||
|
||||
rp.peers[rp.peerData.relayIP].RelayControlData(buf)
|
||||
}
|
||||
|
129
node/peer.go
129
node/peer.go
@ -22,19 +22,14 @@ type peerData struct {
|
||||
|
||||
type remotePeer struct {
|
||||
// Immutable data.
|
||||
localIP byte
|
||||
remoteIP byte
|
||||
privKey []byte
|
||||
localPublic bool // True if local node is public.
|
||||
iface *ifWriter
|
||||
conn *connWriter
|
||||
localIP byte
|
||||
remoteIP byte
|
||||
iface *ifWriter
|
||||
conn *connWriter
|
||||
|
||||
// Shared state.
|
||||
peers *remotePeers
|
||||
shared *atomic.Pointer[peerData]
|
||||
|
||||
// Only used in HandlePeerUpdate.
|
||||
peerVersion int64
|
||||
peers *remotePeers
|
||||
published *atomic.Pointer[peerData]
|
||||
|
||||
// Only used in HandlePacket / Not synchronized.
|
||||
dupCheck *dupCheck
|
||||
@ -55,12 +50,10 @@ func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *conn
|
||||
rp := &remotePeer{
|
||||
localIP: conf.PeerIP,
|
||||
remoteIP: remoteIP,
|
||||
privKey: conf.EncPrivKey,
|
||||
localPublic: addrIsValid(conf.PublicIP),
|
||||
iface: iface,
|
||||
conn: conn,
|
||||
peers: peers,
|
||||
shared: &atomic.Pointer[peerData]{},
|
||||
published: &atomic.Pointer[peerData]{},
|
||||
dupCheck: newDupCheck(0),
|
||||
decryptBuf: make([]byte, bufferSize),
|
||||
encryptBuf: make([]byte, bufferSize),
|
||||
@ -70,10 +63,10 @@ func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *conn
|
||||
}
|
||||
|
||||
pd := peerData{}
|
||||
rp.shared.Store(&pd)
|
||||
|
||||
go newPeerSuper(rp).Run()
|
||||
rp.published.Store(&pd)
|
||||
|
||||
//go newPeerSuper(rp).Run()
|
||||
go rp.supervise(conf, remoteIP, conn, peers)
|
||||
return rp
|
||||
}
|
||||
|
||||
@ -82,10 +75,7 @@ func (rp *remotePeer) logf(msg string, args ...any) {
|
||||
}
|
||||
|
||||
func (rp *remotePeer) HandlePeerUpdate(peer *m.Peer) {
|
||||
if peer != nil && peer.Version != rp.peerVersion {
|
||||
rp.peerUpdates <- peer
|
||||
rp.peerVersion = peer.Version
|
||||
}
|
||||
rp.peerUpdates <- peer
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
@ -101,6 +91,9 @@ func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h header, data []byte) {
|
||||
case dataStreamID:
|
||||
rp.handleDataPacket(data)
|
||||
|
||||
case relayStreamID:
|
||||
rp.handleRelayPacket(h, data)
|
||||
|
||||
default:
|
||||
rp.logf("Unknown stream ID: %d", h.StreamID)
|
||||
}
|
||||
@ -109,8 +102,9 @@ func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h header, data []byte) {
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []byte) {
|
||||
shared := rp.shared.Load()
|
||||
shared := rp.published.Load()
|
||||
if shared.controlCipher == nil {
|
||||
log.Printf("Shared: %+v", *shared)
|
||||
rp.logf("Not connected (control).")
|
||||
return
|
||||
}
|
||||
@ -141,19 +135,7 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []
|
||||
RemoteAddr: addr,
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
switch out[0] {
|
||||
case packetTypePing:
|
||||
pkt.Payload, err = parsePingPacket(out)
|
||||
case packetTypePong:
|
||||
pkt.Payload, err = parsePongPacket(out)
|
||||
default:
|
||||
rp.logf("Unknown control packet type: %d", out[0])
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err := pkt.ParsePayload(out); err != nil {
|
||||
rp.logf("Failed to parse control packet: %v", err)
|
||||
return
|
||||
}
|
||||
@ -168,7 +150,7 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data []
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (rp *remotePeer) handleDataPacket(data []byte) {
|
||||
shared := rp.shared.Load()
|
||||
shared := rp.published.Load()
|
||||
if shared.dataCipher == nil {
|
||||
rp.logf("Not connected (recv).")
|
||||
return
|
||||
@ -185,34 +167,65 @@ func (rp *remotePeer) handleDataPacket(data []byte) {
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (rp *remotePeer) handleRelayPacket(h header, data []byte) {
|
||||
shared := rp.published.Load()
|
||||
if shared.dataCipher == nil {
|
||||
rp.logf("Not connected (recv).")
|
||||
return
|
||||
}
|
||||
|
||||
dec, ok := shared.dataCipher.Decrypt(data, rp.decryptBuf)
|
||||
if !ok {
|
||||
rp.logf("Failed to decrypt data packet.")
|
||||
return
|
||||
}
|
||||
|
||||
rp.peers[h.DestIP].sendDirect(dec)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// SendData sends data coming from the interface going to the network.
|
||||
//
|
||||
// This function is called by a single thread.
|
||||
func (rp *remotePeer) SendData(data []byte) {
|
||||
rp.sendData(dataStreamID, data)
|
||||
rp.sendData(dataStreamID, rp.remoteIP, data)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
func (rp *remotePeer) HandleInterfacePacket(data []byte) {
|
||||
shared := rp.published.Load()
|
||||
|
||||
func (rp *remotePeer) RelayControlData(data []byte) {
|
||||
rp.sendData(forwardStreamID, data)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (rp *remotePeer) ForwardPacket(data []byte) {
|
||||
shared := rp.shared.Load()
|
||||
if shared.remoteAddr == zeroAddrPort {
|
||||
rp.logf("Not connected (forward).")
|
||||
if shared.dataCipher == nil {
|
||||
rp.logf("Not connected (handle interface).")
|
||||
return
|
||||
}
|
||||
rp.conn.WriteTo(data, shared.remoteAddr)
|
||||
|
||||
h := header{
|
||||
StreamID: dataStreamID,
|
||||
Counter: atomic.AddUint64(&rp.counter, 1),
|
||||
SourceIP: rp.localIP,
|
||||
DestIP: rp.remoteIP,
|
||||
}
|
||||
|
||||
enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf)
|
||||
|
||||
if shared.relayIP != 0 {
|
||||
rp.peers[shared.relayIP].RelayData(shared.relayIP, enc)
|
||||
} else {
|
||||
rp.SendData(data)
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (rp *remotePeer) sendData(streamID byte, data []byte) {
|
||||
shared := rp.shared.Load()
|
||||
func (rp *remotePeer) RelayData(destIP byte, data []byte) {
|
||||
rp.sendData(relayStreamID, destIP, data)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (rp *remotePeer) sendData(streamID byte, destIP byte, data []byte) {
|
||||
shared := rp.published.Load()
|
||||
if shared.dataCipher == nil || shared.remoteAddr == zeroAddrPort {
|
||||
rp.logf("Not connected (send).")
|
||||
return
|
||||
@ -222,16 +235,18 @@ func (rp *remotePeer) sendData(streamID byte, data []byte) {
|
||||
StreamID: streamID,
|
||||
Counter: atomic.AddUint64(&rp.counter, 1),
|
||||
SourceIP: rp.localIP,
|
||||
DestIP: rp.remoteIP,
|
||||
DestIP: destIP,
|
||||
}
|
||||
|
||||
enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf)
|
||||
rp.conn.WriteTo(enc, shared.remoteAddr)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (rp *remotePeer) canRelay() bool {
|
||||
shared := rp.shared.Load()
|
||||
return shared.relay && shared.up
|
||||
func (rp *remotePeer) sendDirect(data []byte) {
|
||||
shared := rp.published.Load()
|
||||
if shared.remoteAddr == zeroAddrPort {
|
||||
rp.logf("Not connected (send).")
|
||||
return
|
||||
}
|
||||
rp.conn.WriteTo(data, shared.remoteAddr)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user