// vppn/peer/state-clientinit.go
//
// (Source-viewer metadata: 105 lines, 2.2 KiB, Go.)

package peer
import (
"net/netip"
"time"
)
// stateClientInit is the peer state in which this node, acting as the
// client, periodically sends INIT packets and waits for a reply whose trace
// ID matches the most recently sent one.
type stateClientInit struct {
// Embedded shared peer context (staged state, peer config, ping timer,
// logging, Send/publish) accessed by this state's methods.
*peerData
// Start of the current connection attempt. onPing compares it against
// timeoutInterval, and it is reset when falling back from a direct attempt.
startedAt time.Time
// Trace ID of the most recently sent INIT; onInit rejects replies that do
// not echo it.
traceID uint64
}
// enterStateClientInit moves the peer into the client-side INIT handshake.
//
// It resets the staged connection state (not up, not relaying) and enables
// a direct attempt only when the peer's PublicIP parses as an address. It
// installs a new control cipher (from our private key and the peer's public
// key) and a new data cipher, then publishes the staged state. Finally it
// sends the first INIT, arms the ping timer and returns the new state.
func enterStateClientInit(data *peerData) peerState {
	// netip.AddrFromSlice only succeeds for 4- or 16-byte slices; anything
	// else (including an empty PublicIP) means no direct route is attempted.
	ip, ipValid := netip.AddrFromSlice(data.peer.PublicIP)

	data.staged.Up = false
	data.staged.Relay = false
	data.staged.Direct = ipValid
	// Set unconditionally; holds the zero Addr when PublicIP did not parse.
	data.staged.DirectAddr = netip.AddrPortFrom(ip, data.peer.Port)
	data.staged.PubSignKey = data.peer.PubSignKey
	data.staged.ControlCipher = newControlCipher(data.privKey, data.peer.PubKey)
	data.staged.DataCipher = newDataCipher()
	data.publish(data.staged)

	// traceID is intentionally not seeded here: sendInit generates and stores
	// a fresh trace ID for every INIT it sends, so any value set now would be
	// overwritten before it could ever be compared.
	state := &stateClientInit{
		peerData:  data,
		startedAt: time.Now(),
	}
	state.sendInit()
	data.pingTimer.Reset(pingInterval)
	state.logf("==> ClientInit")
	return state
}
// logf forwards a formatted log line to the embedded peerData logger,
// tagging it with the "INIT | " prefix so client-init messages stand out.
func (s *stateClientInit) logf(str string, args ...any) {
	const tag = "INIT | "
	// Call the embedded logger explicitly; s.logf would recurse.
	s.peerData.logf(tag+str, args...)
}
// OnMsg dispatches one incoming message and returns the next peer state.
//
// A peer update restarts the state machine with the new peer config. An INIT
// reply is validated by onInit, and ping-timer ticks drive retransmission or
// timeout via onPing. SYN and ACK are logged as unexpected. Probe and
// local-discovery packets are dropped silently. Anything else is logged and
// ignored. Every message not handled by one of these transitions leaves the
// state unchanged.
func (s *stateClientInit) OnMsg(raw any) peerState {
	switch msg := raw.(type) {
	case peerUpdateMsg:
		return initPeerState(s.peerData, msg.Peer)
	case controlMsg[packetInit]:
		return s.onInit(msg)
	case controlMsg[packetSyn]:
		s.logf("Unexpected SYN")
	case controlMsg[packetAck]:
		s.logf("Unexpected ACK")
	case controlMsg[packetProbe], controlMsg[packetLocalDiscovery]:
		// Not meaningful while still initializing; drop without logging.
	case pingTimerMsg:
		return s.onPing()
	default:
		s.logf("Ignoring message: %#v", raw)
	}
	return s
}
// onInit handles an INIT reply. The reply is accepted only when its TraceID
// equals the ID of the most recently sent INIT; the peer then advances to
// the client state. Otherwise the reply is logged and ignored.
//
// sendInit regenerates the trace ID on every (re)transmission. A reply to an
// older INIT that arrives after a retransmit will therefore not match.
func (s *stateClientInit) onInit(msg controlMsg[packetInit]) peerState {
	pkt := msg.Packet
	if pkt.TraceID == s.traceID {
		s.logf("Got INIT version %d.", pkt.Version)
		return enterStateClient(s.peerData)
	}
	s.logf("Invalid trace ID on INIT.")
	return s
}
// onPing runs on each ping-timer tick.
//
// While less than timeoutInterval has passed since startedAt, it retransmits
// INIT. Once the window expires during a direct attempt, it clears the
// staged Direct flag, publishes that change and restarts the window. It then
// retries without the direct path. If the window expires with no direct
// attempt pending, the peer state machine is restarted via initPeerState.
func (s *stateClientInit) onPing() peerState {
	switch {
	case time.Since(s.startedAt) < timeoutInterval:
		s.sendInit()
	case s.staged.Direct:
		// Direct path did not answer in time; fall back to an indirect one.
		s.staged.Direct = false
		s.publish(s.staged)
		s.startedAt = time.Now()
		s.sendInit()
		s.logf("Direct connection failed. Attempting indirect connection.")
	default:
		s.logf("Timeout.")
		return initPeerState(s.peerData, s.peer)
	}
	return s
}
// sendInit generates a new trace ID and records it in s.traceID so onInit
// can match the reply. It then sends an INIT packet through s.Send using the
// staged state. The packet carries that trace ID, the current staged Direct
// flag and the local protocol version.
func (s *stateClientInit) sendInit() {
	id := newTraceID()
	s.traceID = id
	s.Send(s.staged, packetInit{
		TraceID: id,
		Direct:  s.staged.Direct,
		Version: version,
	})
}