This commit is contained in:
jdl
2024-12-13 15:49:15 +01:00
parent fdf0066fc2
commit 24517c02f1
5 changed files with 40 additions and 40 deletions

View File

@@ -14,7 +14,6 @@ type connSender struct {
streamID byte streamID byte
encrypted []byte encrypted []byte
nonceBuf []byte nonceBuf []byte
counterTS uint64
counter uint64 counter uint64
signingKey []byte signingKey []byte
} }
@@ -26,20 +25,15 @@ func newConnSender(conn *net.UDPConn, srcIP, streamID byte, signingPrivKey []byt
streamID: streamID, streamID: streamID,
encrypted: make([]byte, BUFFER_SIZE), encrypted: make([]byte, BUFFER_SIZE),
nonceBuf: make([]byte, NONCE_SIZE), nonceBuf: make([]byte, NONCE_SIZE),
counter: uint64(fasttime.Now()) << 30, // Ensure counter is always increasing.
signingKey: signingPrivKey, signingKey: signingPrivKey,
} }
} }
func (cs *connSender) send(packetType byte, packet []byte, route *route) { func (cs *connSender) send(packetType byte, packet []byte, route *route) {
now := uint64(fasttime.Now())
if cs.counterTS < now {
cs.counterTS = now
cs.counter = now << 30
}
cs.counter++ cs.counter++
nonce := Nonce{ nonce := Nonce{
Timestamp: fasttime.Now(),
Counter: cs.counter, Counter: cs.counter,
SourceIP: cs.sourceIP, SourceIP: cs.sourceIP,
ViaIP: route.ViaIP, ViaIP: route.ViaIP,
@@ -59,7 +53,6 @@ func (cs *connSender) send(packetType byte, packet []byte, route *route) {
toSend = encrypted toSend = encrypted
} }
log.Printf("Sending to %v: %+v", route.Addr, nonce)
if _, err := cs.conn.WriteToUDPAddrPort(toSend, route.Addr); err != nil { if _, err := cs.conn.WriteToUDPAddrPort(toSend, route.Addr); err != nil {
log.Fatalf("Failed to write UDP packet: %v\n%s", err, debug.Stack()) log.Fatalf("Failed to write UDP packet: %v\n%s", err, debug.Stack())
} }

View File

@@ -5,8 +5,8 @@ const (
NONCE_SIZE = 24 NONCE_SIZE = 24
KEY_SIZE = 32 KEY_SIZE = 32
SIG_SIZE = 64 SIG_SIZE = 64
MTU = 1408 MTU = 1376
BUFFER_SIZE = MTU + NONCE_SIZE + SIG_SIZE BUFFER_SIZE = 2048 // Definitely big enough.
STREAM_DATA = 0 STREAM_DATA = 0
STREAM_ROUTING = 1 // Routing queries and responses. STREAM_ROUTING = 1 // Routing queries and responses.

View File

@@ -3,6 +3,7 @@ package peer
import "unsafe" import "unsafe"
type Nonce struct { type Nonce struct {
Timestamp int64
Counter uint64 Counter uint64
SourceIP byte SourceIP byte
ViaIP byte ViaIP byte
@@ -12,23 +13,21 @@ type Nonce struct {
} }
func (nonce *Nonce) Parse(nb []byte) { func (nonce *Nonce) Parse(nb []byte) {
nonce.Counter = *(*uint64)(unsafe.Pointer(&nb[0])) nonce.Timestamp = *(*int64)(unsafe.Pointer(&nb[0]))
nonce.SourceIP = nb[8] nonce.Counter = *(*uint64)(unsafe.Pointer(&nb[8]))
nonce.ViaIP = nb[9] nonce.SourceIP = nb[16]
nonce.DestIP = nb[10] nonce.ViaIP = nb[17]
nonce.StreamID = nb[11] nonce.DestIP = nb[18]
nonce.PacketType = nb[12] nonce.StreamID = nb[19]
nonce.PacketType = nb[20]
} }
func (nonce Nonce) Marshal(buf []byte) { func (nonce Nonce) Marshal(buf []byte) {
*(*uint64)(unsafe.Pointer(&buf[0])) = nonce.Counter *(*int64)(unsafe.Pointer(&buf[0])) = nonce.Timestamp
buf[8] = nonce.SourceIP *(*uint64)(unsafe.Pointer(&buf[8])) = nonce.Counter
buf[9] = nonce.ViaIP buf[16] = nonce.SourceIP
buf[10] = nonce.DestIP buf[17] = nonce.ViaIP
buf[11] = nonce.StreamID buf[18] = nonce.DestIP
buf[12] = nonce.PacketType buf[19] = nonce.StreamID
} buf[20] = nonce.PacketType
func CounterTimestamp(counter uint64) int64 {
return int64(counter >> 30)
} }

View File

@@ -14,14 +14,15 @@ func (peer *Peer) ifReader() {
}() }()
var ( var (
sender = newConnSender(peer.conn, peer.ip, STREAM_DATA, peer.signPrivKey) sender = newConnSender(peer.conn, peer.ip, STREAM_DATA, peer.signPrivKey)
n int n int
destIP byte destIP byte
router = peer.router router = peer.router
route *route route *route
iface = peer.iface iface = peer.iface
err error err error
packet = make([]byte, BUFFER_SIZE) packet = make([]byte, BUFFER_SIZE)
version byte
) )
for { for {
@@ -36,8 +37,16 @@ func (peer *Peer) ifReader() {
} }
packet = packet[:n] packet = packet[:n]
version = packet[0] >> 4
destIP = packet[19] switch version {
case 4:
destIP = packet[19]
case 6:
destIP = packet[39]
default:
log.Printf("Dropping packet with IP version: %d", version)
}
route = router.GetRoute(destIP) route = router.GetRoute(destIP)
if route == nil || !route.Up { if route == nil || !route.Up {

View File

@@ -46,8 +46,8 @@ NEXT_PACKET:
nonce.Parse(packet[n-NONCE_SIZE:]) nonce.Parse(packet[n-NONCE_SIZE:])
// Drop after 8 seconds. // Drop after 8 seconds.
if CounterTimestamp(nonce.Counter) < fasttime.Now()-8 { if nonce.Timestamp < fasttime.Now()-8 {
log.Printf("Dropping old packet: %d", CounterTimestamp(nonce.Counter)) log.Printf("Dropping old packet: %d", nonce.Timestamp)
goto NEXT_PACKET goto NEXT_PACKET
} }
@@ -56,9 +56,8 @@ NEXT_PACKET:
goto NEXT_PACKET goto NEXT_PACKET
} }
// Check source counter.
if nonce.Counter <= counters[nonce.StreamID][nonce.SourceIP] { if nonce.Counter <= counters[nonce.StreamID][nonce.SourceIP] {
log.Printf("Dropping packet with bad counter: %+v", nonce) log.Printf("Dropping packet with bad counter: %d <= %d", nonce.Counter, counters[nonce.StreamID][nonce.SourceIP])
goto NEXT_PACKET goto NEXT_PACKET
} }