Client-server working. No mediator.
This commit is contained in:
		
							
								
								
									
										58
									
								
								node/conn.go
									
									
									
									
									
								
							
							
						
						
									
										58
									
								
								node/conn.go
									
									
									
									
									
								
							| @@ -4,30 +4,26 @@ import ( | ||||
| 	"log" | ||||
| 	"net" | ||||
| 	"net/netip" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| 	"vppn/fasttime" | ||||
| ) | ||||
|  | ||||
| // TODO: | ||||
| type connRouter interface { | ||||
| 	Lookup(byte) *peer | ||||
| 	Mediator() *peer | ||||
| } | ||||
|  | ||||
| type connWriter struct { | ||||
| 	*net.UDPConn | ||||
| 	lock     sync.Mutex | ||||
| 	localIP  byte | ||||
| 	buf      []byte | ||||
| 	counters [256]uint64 | ||||
| 	lookup   func(byte) *peer | ||||
| 	routing  *routingTable | ||||
| } | ||||
|  | ||||
| func newConnWriter(conn *net.UDPConn, localIP byte, lookup func(byte) *peer) *connWriter { | ||||
| func newConnWriter(conn *net.UDPConn, localIP byte, routing *routingTable) *connWriter { | ||||
| 	w := &connWriter{ | ||||
| 		UDPConn: conn, | ||||
| 		localIP: localIP, | ||||
| 		buf:     make([]byte, bufferSize), | ||||
| 		lookup:  lookup, | ||||
| 		routing: routing, | ||||
| 	} | ||||
|  | ||||
| 	for i := range w.counters { | ||||
| @@ -37,24 +33,36 @@ func newConnWriter(conn *net.UDPConn, localIP byte, lookup func(byte) *peer) *co | ||||
| 	return w | ||||
| } | ||||
|  | ||||
| func (w *connWriter) WriteTo(remoteIP, packetType byte, data []byte) error { | ||||
| 	peer := w.lookup(remoteIP) | ||||
| func (w *connWriter) WriteTo(remoteIP, stream byte, data []byte) error { | ||||
| 	// TODO: Handle mediator. | ||||
| 	peer := w.routing.Get(remoteIP) | ||||
| 	if peer == nil || peer.Addr == nil { | ||||
| 		log.Printf("No peer: %d", remoteIP) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	if stream == streamData && !peer.Up { | ||||
| 		log.Printf("Peer down: %d", remoteIP) | ||||
| 	} | ||||
| 	return w.WriteToPeer(peer, stream, data) | ||||
| } | ||||
|  | ||||
| func (w *connWriter) WriteToPeer(peer *peer, stream byte, data []byte) error { | ||||
| 	w.lock.Lock() | ||||
|  | ||||
| 	remoteIP := peer.IP | ||||
| 	h := header{ | ||||
| 		Counter:    atomic.AddUint64(&w.counters[remoteIP], 1), | ||||
| 		SourceIP:   w.localIP, | ||||
| 		ViaIP:      0, | ||||
| 		DestIP:     remoteIP, | ||||
| 		PacketType: packetType, | ||||
| 		Counter:  atomic.AddUint64(&w.counters[remoteIP], 1), | ||||
| 		SourceIP: w.localIP, | ||||
| 		ViaIP:    0, | ||||
| 		DestIP:   remoteIP, | ||||
| 		Stream:   stream, | ||||
| 	} | ||||
|  | ||||
| 	buf := encryptPacket(&h, peer.SharedKey, data, w.buf) | ||||
|  | ||||
| 	_, err := w.WriteToUDPAddrPort(buf, *peer.Addr) | ||||
| 	w.lock.Unlock() | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| @@ -64,15 +72,15 @@ type connReader struct { | ||||
| 	*net.UDPConn | ||||
| 	localIP   byte | ||||
| 	dupChecks [256]*dupCheck | ||||
| 	lookup    func(byte) *peer | ||||
| 	routing   *routingTable | ||||
| 	buf       []byte | ||||
| } | ||||
|  | ||||
| func newConnReader(conn *net.UDPConn, localIP byte, lookup func(byte) *peer) *connReader { | ||||
| func newConnReader(conn *net.UDPConn, localIP byte, routing *routingTable) *connReader { | ||||
| 	r := &connReader{ | ||||
| 		UDPConn: conn, | ||||
| 		localIP: localIP, | ||||
| 		lookup:  lookup, | ||||
| 		routing: routing, | ||||
| 		buf:     make([]byte, bufferSize), | ||||
| 	} | ||||
| 	for i := range r.dupChecks { | ||||
| @@ -93,18 +101,20 @@ func (r *connReader) Read(buf []byte) (remoteAddr netip.AddrPort, h header, data | ||||
| 		data = buf[:n] | ||||
|  | ||||
| 		if n < headerSize { | ||||
| 			log.Printf("Dropping short packet: %d", n) | ||||
| 			continue // Packet it soo short. | ||||
| 		} | ||||
|  | ||||
| 		h.Parse(data) | ||||
|  | ||||
| 		if len(data) != headerSize+int(h.DataSize) { | ||||
| 			log.Printf("Incorrect size") | ||||
| 			continue // Packet is corrupt. | ||||
| 			log.Printf("Malformed packet: %d != %d", len(data), headerSize+int(h.DataSize)) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		peer := r.lookup(h.SourceIP) | ||||
| 		peer := r.routing.Get(h.SourceIP) | ||||
| 		if peer == nil { | ||||
| 			log.Printf("No peer...") | ||||
| 			log.Printf("No peer: %d...", h.SourceIP) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| @@ -117,7 +127,7 @@ func (r *connReader) Read(buf []byte) (remoteAddr netip.AddrPort, h header, data | ||||
| 		out, data = data, out | ||||
|  | ||||
| 		if r.dupChecks[h.SourceIP].IsDup(h.Counter) { | ||||
| 			log.Printf("Duplicate...") | ||||
| 			log.Printf("Duplicate: %d", h.Counter) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,11 @@ | ||||
| package node | ||||
|  | ||||
| import "golang.org/x/crypto/nacl/box" | ||||
| import ( | ||||
| 	"sync" | ||||
| 	"vppn/fasttime" | ||||
|  | ||||
| 	"golang.org/x/crypto/nacl/box" | ||||
| ) | ||||
|  | ||||
| // Encrypting the packet will also set the header's DataSize field. | ||||
| func encryptPacket(h *header, sharedKey, data, out []byte) []byte { | ||||
| @@ -24,3 +29,23 @@ func computeSharedKey(peerPubKey, privKey []byte) []byte { | ||||
| 	box.Precompute(&shared, (*[32]byte)(peerPubKey), (*[32]byte)(privKey)) | ||||
| 	return shared[:] | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	traceIDLock    sync.Mutex | ||||
| 	traceIDTime    uint64 | ||||
| 	traceIDCounter uint64 | ||||
| ) | ||||
|  | ||||
| func newTraceID() (id uint64) { | ||||
| 	traceIDLock.Lock() | ||||
| 	defer traceIDLock.Unlock() | ||||
|  | ||||
| 	now := uint64(fasttime.Now()) | ||||
| 	if traceIDTime < now { | ||||
| 		traceIDTime = now | ||||
| 		traceIDCounter = 0 | ||||
| 	} | ||||
| 	traceIDCounter++ | ||||
|  | ||||
| 	return traceIDTime<<30 + traceIDCounter | ||||
| } | ||||
|   | ||||
| @@ -33,11 +33,11 @@ func TestEncryptDecryptPacket(t *testing.T) { | ||||
| 	rand.Read(original) | ||||
|  | ||||
| 	h := header{ | ||||
| 		Counter:    2893749238, | ||||
| 		SourceIP:   5, | ||||
| 		ViaIP:      8, | ||||
| 		DestIP:     12, | ||||
| 		PacketType: 32, | ||||
| 		Counter:  2893749238, | ||||
| 		SourceIP: 5, | ||||
| 		ViaIP:    8, | ||||
| 		DestIP:   12, | ||||
| 		Stream:   1, | ||||
| 	} | ||||
|  | ||||
| 	encrypted := make([]byte, bufferSize) | ||||
| @@ -62,7 +62,6 @@ func TestEncryptDecryptPacket(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| /* | ||||
| func BenchmarkEncryptPacket(b *testing.B) { | ||||
| 	_, privKey1, err := box.GenerateKey(rand.Reader) | ||||
| 	if err != nil { | ||||
| @@ -77,16 +76,24 @@ func BenchmarkEncryptPacket(b *testing.B) { | ||||
| 	sharedEncKey := [32]byte{} | ||||
| 	box.Precompute(&sharedEncKey, pubKey2, privKey1) | ||||
|  | ||||
| 	original := make([]byte, MTU) | ||||
| 	original := make([]byte, if_mtu) | ||||
| 	rand.Read(original) | ||||
|  | ||||
| 	nonce := make([]byte, NONCE_SIZE) | ||||
| 	nonce := make([]byte, headerSize) | ||||
| 	rand.Read(nonce) | ||||
|  | ||||
| 	encrypted := make([]byte, BUFFER_SIZE) | ||||
| 	encrypted := make([]byte, bufferSize) | ||||
|  | ||||
| 	h := header{ | ||||
| 		Counter:  2893749238, | ||||
| 		SourceIP: 5, | ||||
| 		ViaIP:    8, | ||||
| 		DestIP:   12, | ||||
| 		Stream:   1, | ||||
| 	} | ||||
|  | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		encrypted = encryptPacket(sharedEncKey[:], nonce, original, encrypted) | ||||
| 		encrypted = encryptPacket(&h, sharedEncKey[:], original, encrypted) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -107,18 +114,27 @@ func BenchmarkDecryptPacket(b *testing.B) { | ||||
| 	sharedDecKey := [32]byte{} | ||||
| 	box.Precompute(&sharedDecKey, pubKey1, privKey2) | ||||
|  | ||||
| 	original := make([]byte, MTU) | ||||
| 	original := make([]byte, if_mtu) | ||||
| 	rand.Read(original) | ||||
|  | ||||
| 	nonce := make([]byte, NONCE_SIZE) | ||||
| 	nonce := make([]byte, headerSize) | ||||
| 	rand.Read(nonce) | ||||
|  | ||||
| 	encrypted := make([]byte, BUFFER_SIZE) | ||||
| 	encrypted = encryptPacket(sharedEncKey[:], nonce, original, encrypted) | ||||
| 	h := header{ | ||||
| 		Counter:  2893749238, | ||||
| 		SourceIP: 5, | ||||
| 		ViaIP:    8, | ||||
| 		DestIP:   12, | ||||
| 		Stream:   1, | ||||
| 	} | ||||
|  | ||||
| 	decrypted := make([]byte, MTU) | ||||
| 	encrypted := encryptPacket(&h, sharedEncKey[:], original, make([]byte, bufferSize)) | ||||
| 	decrypted := make([]byte, bufferSize) | ||||
| 	var ok bool | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		decrypted, _ = decryptPacket(sharedDecKey[:], encrypted, decrypted) | ||||
| 		decrypted, ok = decryptPacket(sharedDecKey[:], encrypted, decrypted) | ||||
| 		if !ok { | ||||
| 			panic(ok) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| */ | ||||
|   | ||||
| @@ -1,5 +1,7 @@ | ||||
| package node | ||||
|  | ||||
| import "log" | ||||
|  | ||||
| type dupCheck struct { | ||||
| 	bitSet | ||||
| 	head        int | ||||
| @@ -20,6 +22,7 @@ func (dc *dupCheck) IsDup(counter uint64) bool { | ||||
|  | ||||
| 	// Before head => it's late, say it's a dup. | ||||
| 	if counter < dc.headCounter { | ||||
| 		log.Printf("Late: %d", counter) | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| @@ -27,6 +30,7 @@ func (dc *dupCheck) IsDup(counter uint64) bool { | ||||
| 	if counter < dc.tailCounter { | ||||
| 		index := (int(counter-dc.headCounter) + dc.head) % bitSetSize | ||||
| 		if dc.Get(index) { | ||||
| 			log.Printf("Dup: %d, %d", counter, dc.tailCounter) | ||||
| 			return true | ||||
| 		} | ||||
|  | ||||
|   | ||||
							
								
								
									
										82
									
								
								node/files.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										82
									
								
								node/files.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,82 @@ | ||||
| package node | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"vppn/m" | ||||
| ) | ||||
|  | ||||
| func configDir(netName string) string { | ||||
| 	d, err := os.UserHomeDir() | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to get user home directory: %v", err) | ||||
| 	} | ||||
| 	return filepath.Join(d, ".vppn", netName) | ||||
| } | ||||
|  | ||||
| func peerConfigPath(netName string) string { | ||||
| 	return filepath.Join(configDir(netName), "peer-config.json") | ||||
| } | ||||
|  | ||||
| func peerStatePath(netName string) string { | ||||
| 	return filepath.Join(configDir(netName), "peer-state.json") | ||||
| } | ||||
|  | ||||
| func storeJson(x any, outPath string) error { | ||||
| 	outDir := filepath.Dir(outPath) | ||||
| 	_ = os.MkdirAll(outDir, 0700) | ||||
|  | ||||
| 	tmpPath := outPath + ".tmp" | ||||
| 	buf, err := json.Marshal(x) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	f, err := os.Create(tmpPath) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if _, err := f.Write(buf); err != nil { | ||||
| 		f.Close() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if err := f.Sync(); err != nil { | ||||
| 		f.Close() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if err := f.Close(); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return os.Rename(tmpPath, outPath) | ||||
| } | ||||
|  | ||||
| func storePeerConfig(netName string, pc m.PeerConfig) error { | ||||
| 	return storeJson(pc, peerConfigPath(netName)) | ||||
| } | ||||
|  | ||||
| func storeNetworkState(netName string, ps m.NetworkState) error { | ||||
| 	return storeJson(ps, peerStatePath(netName)) | ||||
| } | ||||
|  | ||||
| func loadJson(dataPath string, ptr any) error { | ||||
| 	data, err := os.ReadFile(dataPath) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return json.Unmarshal(data, ptr) | ||||
| } | ||||
|  | ||||
| func loadPeerConfig(netName string) (pc m.PeerConfig, err error) { | ||||
| 	return pc, loadJson(peerConfigPath(netName), &pc) | ||||
| } | ||||
|  | ||||
| func loadNetworkState(netName string) (ps m.NetworkState, err error) { | ||||
| 	return ps, loadJson(peerStatePath(netName), &ps) | ||||
| } | ||||
| @@ -2,15 +2,19 @@ package node | ||||
|  | ||||
| import "unsafe" | ||||
|  | ||||
| const headerSize = 24 | ||||
| const ( | ||||
| 	headerSize    = 24 | ||||
| 	streamData    = 1 | ||||
| 	streamRouting = 2 | ||||
| ) | ||||
|  | ||||
| type header struct { | ||||
| 	Counter    uint64 // Init with fasttime.Now() << 30 to ensure monotonic. | ||||
| 	SourceIP   byte | ||||
| 	ViaIP      byte | ||||
| 	DestIP     byte | ||||
| 	PacketType byte   // The packet type. See PACKET_* constants. | ||||
| 	DataSize   uint16 // Data size following associated data. | ||||
| 	Counter  uint64 // Init with fasttime.Now() << 30 to ensure monotonic. | ||||
| 	SourceIP byte | ||||
| 	ViaIP    byte | ||||
| 	DestIP   byte | ||||
| 	Stream   byte   // See stream* constants. | ||||
| 	DataSize uint16 // Data size following associated data. | ||||
| } | ||||
|  | ||||
| func (hdr *header) Parse(nb []byte) { | ||||
| @@ -18,7 +22,7 @@ func (hdr *header) Parse(nb []byte) { | ||||
| 	hdr.SourceIP = nb[8] | ||||
| 	hdr.ViaIP = nb[9] | ||||
| 	hdr.DestIP = nb[10] | ||||
| 	hdr.PacketType = nb[11] | ||||
| 	hdr.Stream = nb[11] | ||||
| 	hdr.DataSize = *(*uint16)(unsafe.Pointer(&nb[12])) | ||||
| } | ||||
|  | ||||
| @@ -27,6 +31,6 @@ func (hdr header) Marshal(buf []byte) { | ||||
| 	buf[8] = hdr.SourceIP | ||||
| 	buf[9] = hdr.ViaIP | ||||
| 	buf[10] = hdr.DestIP | ||||
| 	buf[11] = hdr.PacketType | ||||
| 	buf[11] = hdr.Stream | ||||
| 	*(*uint16)(unsafe.Pointer(&buf[12])) = hdr.DataSize | ||||
| } | ||||
|   | ||||
| @@ -4,12 +4,12 @@ import "testing" | ||||
|  | ||||
| func TestHeaderMarshalParse(t *testing.T) { | ||||
| 	nIn := header{ | ||||
| 		Counter:    3212, | ||||
| 		SourceIP:   34, | ||||
| 		ViaIP:      20, | ||||
| 		DestIP:     200, | ||||
| 		PacketType: 44, | ||||
| 		DataSize:   1235, | ||||
| 		Counter:  3212, | ||||
| 		SourceIP: 34, | ||||
| 		ViaIP:    20, | ||||
| 		DestIP:   200, | ||||
| 		Stream:   44, | ||||
| 		DataSize: 1235, | ||||
| 	} | ||||
|  | ||||
| 	buf := make([]byte, headerSize) | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package node | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net" | ||||
| 	"os" | ||||
| 	"syscall" | ||||
| @@ -22,23 +23,27 @@ func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error) | ||||
| 			return nil, ip, err | ||||
| 		} | ||||
|  | ||||
| 		if n < 20 { | ||||
| 			continue // Packet too short. | ||||
| 		} | ||||
|  | ||||
| 		buf = buf[:n] | ||||
| 		version = buf[0] >> 4 | ||||
|  | ||||
| 		switch version { | ||||
| 		case 4: | ||||
| 			if n < 20 { | ||||
| 				log.Printf("Short IPv4 packet: %d", len(buf)) | ||||
| 				continue | ||||
| 			} | ||||
| 			ip = buf[19] | ||||
|  | ||||
| 		case 6: | ||||
| 			if len(buf) < 40 { | ||||
| 				continue // Packet too short. | ||||
| 				log.Printf("Short IPv6 packet: %d", len(buf)) | ||||
| 				continue | ||||
| 			} | ||||
| 			ip = buf[39] | ||||
|  | ||||
| 		default: | ||||
| 			continue // Invalid version. | ||||
| 			log.Printf("Invalid IP packet version: %v", version) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		return buf, ip, nil | ||||
| @@ -47,7 +52,7 @@ func readNextPacket(iface io.ReadWriteCloser, buf []byte) ([]byte, byte, error) | ||||
|  | ||||
| const ( | ||||
| 	if_mtu       = 1200 | ||||
| 	if_queue_len = 1000 | ||||
| 	if_queue_len = 2048 | ||||
| ) | ||||
|  | ||||
| func openInterface(network []byte, localIP byte, name string) (io.ReadWriteCloser, error) { | ||||
|   | ||||
							
								
								
									
										190
									
								
								node/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										190
									
								
								node/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,190 @@ | ||||
| package node | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"flag" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/netip" | ||||
| 	"os" | ||||
| 	"runtime/debug" | ||||
| 	"vppn/m" | ||||
| ) | ||||
|  | ||||
| func panicHandler() { | ||||
| 	if r := recover(); r != nil { | ||||
| 		log.Fatalf("\n %v\n\nstacktrace from panic: %s\n", r, string(debug.Stack())) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func Main() { | ||||
| 	defer panicHandler() | ||||
|  | ||||
| 	var ( | ||||
| 		netName  string | ||||
| 		initURL  string | ||||
| 		listenIP string | ||||
| 		port     int | ||||
| 	) | ||||
|  | ||||
| 	flag.StringVar(&netName, "name", "", "[REQUIRED] The network name.") | ||||
| 	flag.StringVar(&initURL, "init-url", "", "Initializes peer from the hub URL.") | ||||
| 	flag.StringVar(&listenIP, "listen-ip", "", "IP address to listen on.") | ||||
| 	flag.IntVar(&port, "port", 0, "Port to listen on.") | ||||
| 	flag.Parse() | ||||
|  | ||||
| 	if netName == "" { | ||||
| 		flag.Usage() | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
|  | ||||
| 	if initURL != "" { | ||||
| 		mainInit(netName, initURL) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	main(netName, listenIP, uint16(port)) | ||||
| } | ||||
|  | ||||
| func mainInit(netName, initURL string) { | ||||
| 	if _, err := loadPeerConfig(netName); err == nil { | ||||
| 		log.Fatalf("Network is already initialized.") | ||||
| 	} | ||||
|  | ||||
| 	resp, err := http.Get(initURL) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to fetch data from hub: %v", err) | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	data, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to read response body: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	peerConfig := m.PeerConfig{} | ||||
| 	if err := json.Unmarshal(data, &peerConfig); err != nil { | ||||
| 		log.Fatalf("Failed to parse configuration: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if err := storePeerConfig(netName, peerConfig); err != nil { | ||||
| 		log.Fatalf("Failed to store configuration: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	log.Print("Initialization successful.") | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func main(netName, listenIP string, port uint16) { | ||||
| 	conf, err := loadPeerConfig(netName) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to load configuration: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	port = determinePort(conf.Port, port) | ||||
|  | ||||
| 	iface, err := openInterface(conf.Network, conf.PeerIP, netName) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to open interface: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	myAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", listenIP, port)) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to resolve UDP address: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	conn, err := net.ListenUDP("udp", myAddr) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to open UDP port: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	routing := newRoutingTable() | ||||
|  | ||||
| 	w := newConnWriter(conn, conf.PeerIP, routing) | ||||
| 	r := newConnReader(conn, conf.PeerIP, routing) | ||||
|  | ||||
| 	router := newRouter(netName, conf, routing, w) | ||||
|  | ||||
| 	go nodeConnReader(r, w, iface, router) | ||||
| 	nodeIFaceReader(w, iface, router) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func determinePort(confPort, portFromCommandLine uint16) uint16 { | ||||
| 	if portFromCommandLine != 0 { | ||||
| 		return portFromCommandLine | ||||
| 	} | ||||
| 	if confPort != 0 { | ||||
| 		return confPort | ||||
| 	} | ||||
| 	return 456 | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func nodeConnReader(r *connReader, w *connWriter, iface io.ReadWriteCloser, router *router) { | ||||
| 	defer panicHandler() | ||||
| 	var ( | ||||
| 		remoteAddr netip.AddrPort | ||||
| 		h          header | ||||
| 		buf        = make([]byte, bufferSize) | ||||
| 		data       []byte | ||||
| 		err        error | ||||
| 	) | ||||
|  | ||||
| 	for { | ||||
| 		remoteAddr, h, data, err = r.Read(buf) | ||||
| 		if err != nil { | ||||
| 			log.Fatalf("Failed to read from UDP connection: %v", err) | ||||
| 		} | ||||
|  | ||||
| 		switch h.Stream { | ||||
|  | ||||
| 		case streamData: | ||||
| 			if _, err = iface.Write(data); err != nil { | ||||
| 				log.Printf("Malformed data from peer %d: %v", h.SourceIP, err) | ||||
| 			} | ||||
|  | ||||
| 		case streamRouting: | ||||
| 			router.HandlePacket(h.SourceIP, remoteAddr, data) | ||||
|  | ||||
| 		default: | ||||
| 			log.Printf("Dropping unknown stream: %d", h.Stream) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func nodeIFaceReader(w *connWriter, iface io.ReadWriteCloser, router *router) { | ||||
|  | ||||
| 	var ( | ||||
| 		buf      = make([]byte, bufferSize) | ||||
| 		packet   []byte | ||||
| 		remoteIP byte | ||||
| 		err      error | ||||
| 	) | ||||
|  | ||||
| 	for { | ||||
|  | ||||
| 		packet, remoteIP, err = readNextPacket(iface, buf) | ||||
| 		if err != nil { | ||||
| 			log.Fatalf("Failed to read from interface: %v", err) | ||||
| 		} | ||||
|  | ||||
| 		if remoteIP == w.localIP { | ||||
| 			//log.Printf("Incoming packet for self: %x", packet) | ||||
| 			//iface.Write(packet) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		if err := w.WriteTo(remoteIP, streamData, packet); err != nil { | ||||
| 			log.Fatalf("Failed to write to network: %v", err) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										1
									
								
								node/node.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								node/node.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| package node | ||||
							
								
								
									
										29
									
								
								node/peer.go
									
									
									
									
									
								
							
							
						
						
									
										29
									
								
								node/peer.go
									
									
									
									
									
								
							| @@ -1,30 +1 @@ | ||||
| package node | ||||
|  | ||||
| import ( | ||||
| 	"net/netip" | ||||
| 	"sync/atomic" | ||||
| ) | ||||
|  | ||||
| type peer struct { | ||||
| 	IP        byte | ||||
| 	Addr      *netip.AddrPort // If we have direct connection, otherwise use mediator. | ||||
| 	SharedKey []byte | ||||
| } | ||||
|  | ||||
| type peerRepo [256]*atomic.Pointer[peer] | ||||
|  | ||||
| func newPeerRepo() peerRepo { | ||||
| 	pr := peerRepo{} | ||||
| 	for i := range pr { | ||||
| 		pr[i] = &atomic.Pointer[peer]{} | ||||
| 	} | ||||
| 	return pr | ||||
| } | ||||
|  | ||||
| func (pr peerRepo) Get(ip byte) *peer { | ||||
| 	return pr[ip].Load() | ||||
| } | ||||
|  | ||||
| func (pr *peerRepo) Set(ip byte, p *peer) { | ||||
| 	pr[ip].Store(p) | ||||
| } | ||||
|   | ||||
							
								
								
									
										1
									
								
								node/peerstate.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								node/peerstate.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| package node | ||||
							
								
								
									
										163
									
								
								node/router.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										163
									
								
								node/router.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,163 @@ | ||||
| package node | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"net/netip" | ||||
| 	"net/url" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
| 	"vppn/m" | ||||
| ) | ||||
|  | ||||
| type peer struct { | ||||
| 	Up        bool // No data will be sent to peers that are down. | ||||
| 	IP        byte | ||||
| 	Addr      *netip.AddrPort // If we have direct connection, otherwise use mediator. | ||||
| 	SharedKey []byte | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type routingTable struct { | ||||
| 	table    [256]*atomic.Pointer[peer] | ||||
| 	mediator *atomic.Pointer[peer] | ||||
| } | ||||
|  | ||||
| func newRoutingTable() *routingTable { | ||||
| 	r := routingTable{ | ||||
| 		mediator: &atomic.Pointer[peer]{}, | ||||
| 	} | ||||
|  | ||||
| 	for i := range r.table { | ||||
| 		r.table[i] = &atomic.Pointer[peer]{} | ||||
| 	} | ||||
|  | ||||
| 	return &r | ||||
| } | ||||
|  | ||||
| func (r *routingTable) Get(ip byte) *peer { | ||||
| 	return r.table[ip].Load() | ||||
| } | ||||
|  | ||||
| func (r *routingTable) Set(ip byte, p *peer) { | ||||
| 	r.table[ip].Store(p) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type router struct { | ||||
| 	netName string | ||||
| 	*routingTable | ||||
| 	peerSupers [256]*peerSupervisor | ||||
| } | ||||
|  | ||||
| func newRouter(netName string, conf m.PeerConfig, routingData *routingTable, w *connWriter) *router { | ||||
| 	r := &router{ | ||||
| 		netName:      netName, | ||||
| 		routingTable: routingData, | ||||
| 	} | ||||
|  | ||||
| 	for i := range r.peerSupers { | ||||
| 		r.peerSupers[i] = newPeerSupervisor( | ||||
| 			conf, | ||||
| 			byte(i), | ||||
| 			w, | ||||
| 			r.routingTable) | ||||
| 	} | ||||
|  | ||||
| 	// TODO: Handle Mediator | ||||
| 	go r.pollHub(conf) | ||||
|  | ||||
| 	return r | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (r *router) HandlePacket(sourceIP byte, remoteAddr netip.AddrPort, data []byte) { | ||||
| 	p := routingPacket{} | ||||
| 	if err := p.Parse(data); err != nil { | ||||
| 		log.Printf("Dropping malformed routing packet: %v", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	w := routingPacketWrapper{ | ||||
| 		routingPacket: p, | ||||
| 		Addr:          remoteAddr, | ||||
| 	} | ||||
|  | ||||
| 	r.peerSupers[sourceIP].HandlePacket(w) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (r *router) pollHub(conf m.PeerConfig) { | ||||
| 	defer panicHandler() | ||||
|  | ||||
| 	u, err := url.Parse(conf.HubAddress) | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("Failed to parse hub address %s: %v", conf.HubAddress, err) | ||||
| 	} | ||||
| 	u.Path = "/peer/fetch-state/" | ||||
|  | ||||
| 	client := &http.Client{Timeout: 8 * time.Second} | ||||
|  | ||||
| 	req := &http.Request{ | ||||
| 		Method: http.MethodGet, | ||||
| 		URL:    u, | ||||
| 		Header: http.Header{}, | ||||
| 	} | ||||
| 	req.SetBasicAuth("", conf.APIKey) | ||||
|  | ||||
| 	// TODO: Before we start polling, load state from the file system. | ||||
| 	state, err := loadNetworkState(r.netName) | ||||
| 	if err != nil { | ||||
| 		log.Printf("Failed to load network state: %v", err) | ||||
| 		log.Printf("Polling hub...") | ||||
| 		r._pollHub(conf, client, req) | ||||
| 	} else { | ||||
| 		r.applyNetworkState(conf, state) | ||||
| 	} | ||||
|  | ||||
| 	for range time.Tick(64 * time.Second) { | ||||
| 		r._pollHub(conf, client, req) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (r *router) _pollHub(conf m.PeerConfig, client *http.Client, req *http.Request) { | ||||
| 	var state m.NetworkState | ||||
|  | ||||
| 	log.Printf("Fetching peer state from %s...", conf.HubAddress) | ||||
| 	resp, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		log.Printf("Failed to fetch peer state: %v", err) | ||||
| 		return | ||||
| 	} | ||||
| 	body, err := io.ReadAll(resp.Body) | ||||
| 	_ = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		log.Printf("Failed to read body from hub: %v", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if err := json.Unmarshal(body, &state); err != nil { | ||||
| 		log.Printf("Failed to unmarshal response from hub: %v", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	r.applyNetworkState(conf, state) | ||||
|  | ||||
| 	if err := storeNetworkState(r.netName, state); err != nil { | ||||
| 		log.Printf("Failed to store network state: %v", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (r *router) applyNetworkState(conf m.PeerConfig, state m.NetworkState) { | ||||
| 	for i := range state.Peers { | ||||
| 		if i != int(conf.PeerIP) { | ||||
| 			r.peerSupers[i].HandlePeerUpdate(state.Peers[i]) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										44
									
								
								node/routingpacket.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								node/routingpacket.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,44 @@ | ||||
| package node | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"unsafe" | ||||
| ) | ||||
|  | ||||
| var errMalformedPacket = errors.New("malformed packet") | ||||
|  | ||||
| const ( | ||||
| 	packetTypeInvalid = iota | ||||
|  | ||||
| 	// Used to maintain connection. | ||||
| 	packetTypePing | ||||
| 	packetTypePong | ||||
| ) | ||||
|  | ||||
| type routingPacket struct { | ||||
| 	Type    byte   // One of the packetType* constants. | ||||
| 	TraceID uint64 // For matching requests and responses. | ||||
| } | ||||
|  | ||||
| func newRoutingPacket(reqType byte, traceID uint64) routingPacket { | ||||
| 	return routingPacket{ | ||||
| 		Type:    reqType, | ||||
| 		TraceID: traceID, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (p routingPacket) Marshal(buf []byte) []byte { | ||||
| 	buf = buf[:32] // Reserve 32 bytes just in case we need to add anything. | ||||
| 	buf[0] = p.Type | ||||
| 	*(*uint64)(unsafe.Pointer(&buf[1])) = uint64(p.TraceID) | ||||
| 	return buf | ||||
| } | ||||
|  | ||||
| func (p *routingPacket) Parse(buf []byte) error { | ||||
| 	if len(buf) != 32 { | ||||
| 		return errMalformedPacket | ||||
| 	} | ||||
| 	p.Type = buf[0] | ||||
| 	p.TraceID = *(*uint64)(unsafe.Pointer(&buf[1])) | ||||
| 	return nil | ||||
| } | ||||
| @@ -1,14 +1,6 @@ | ||||
| package node | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net" | ||||
| 	"net/netip" | ||||
| 	"runtime/debug" | ||||
| ) | ||||
|  | ||||
| /* | ||||
| var ( | ||||
| 	network  = []byte{10, 1, 1, 0} | ||||
| 	serverIP = byte(1) | ||||
| @@ -30,7 +22,7 @@ func must(err error) { | ||||
| type TmpNode struct { | ||||
| 	network []byte | ||||
| 	localIP byte | ||||
| 	peers   peerRepo | ||||
| 	router  *router | ||||
| 	port    uint16 | ||||
| 	netName string | ||||
| 	iface   io.ReadWriteCloser | ||||
| @@ -46,7 +38,7 @@ func NewTmpNodeServer() *TmpNode { | ||||
| 	n := &TmpNode{ | ||||
| 		localIP: serverIP, | ||||
| 		network: network, | ||||
| 		peers:   newPeerRepo(), | ||||
| 		router:  &router{table: newPeerRepo()}, | ||||
| 		port:    port, | ||||
| 		netName: netName, | ||||
| 		pubKey:  pubKey1, | ||||
| @@ -63,10 +55,10 @@ func NewTmpNodeServer() *TmpNode { | ||||
| 	conn, err := net.ListenUDP("udp", myAddr) | ||||
| 	must(err) | ||||
|  | ||||
| 	n.w = newConnWriter(conn, n.localIP, n.peers.Get) | ||||
| 	n.r = newConnReader(conn, n.localIP, n.peers.Get) | ||||
| 	n.w = newConnWriter(conn, n.localIP, n.router) | ||||
| 	n.r = newConnReader(conn, n.localIP, n.router) | ||||
|  | ||||
| 	n.peers.Set(clientIP, &peer{ | ||||
| 	n.router.table.Set(clientIP, &peer{ | ||||
| 		IP:        clientIP, | ||||
| 		SharedKey: computeSharedKey(pubKey2, n.privKey), | ||||
| 	}) | ||||
| @@ -80,7 +72,7 @@ func NewTmpNodeClient(srvAddrStr string) *TmpNode { | ||||
| 	n := &TmpNode{ | ||||
| 		localIP: clientIP, | ||||
| 		network: network, | ||||
| 		peers:   newPeerRepo(), | ||||
| 		router:  &router{table: newPeerRepo()}, | ||||
| 		port:    port, | ||||
| 		netName: netName, | ||||
| 		pubKey:  pubKey2, | ||||
| @@ -97,13 +89,13 @@ func NewTmpNodeClient(srvAddrStr string) *TmpNode { | ||||
| 	conn, err := net.ListenUDP("udp", myAddr) | ||||
| 	must(err) | ||||
|  | ||||
| 	n.w = newConnWriter(conn, n.localIP, n.peers.Get) | ||||
| 	n.r = newConnReader(conn, n.localIP, n.peers.Get) | ||||
| 	n.w = newConnWriter(conn, n.localIP, n.router) | ||||
| 	n.r = newConnReader(conn, n.localIP, n.router) | ||||
|  | ||||
| 	serverAddr, err := netip.ParseAddrPort(fmt.Sprintf("%s:%d", srvAddrStr, port)) | ||||
| 	must(err) | ||||
|  | ||||
| 	n.peers.Set(serverIP, &peer{ | ||||
| 	n.router.table.Set(serverIP, &peer{ | ||||
| 		IP:        serverIP, | ||||
| 		Addr:      &serverAddr, | ||||
| 		SharedKey: computeSharedKey(pubKey1, n.privKey), | ||||
| @@ -129,7 +121,7 @@ func (n *TmpNode) RunServer() { | ||||
| 	log.Printf("Got remote addr: %d -> %v", h.SourceIP, remoteAddr) | ||||
| 	must(err) | ||||
|  | ||||
| 	n.peers.Set(h.SourceIP, &peer{ | ||||
| 	n.router.table.Set(h.SourceIP, &peer{ | ||||
| 		IP:        h.SourceIP, | ||||
| 		Addr:      &remoteAddr, | ||||
| 		SharedKey: computeSharedKey(pubKey2, n.privKey), | ||||
| @@ -144,7 +136,7 @@ func (n *TmpNode) RunServer() { | ||||
| func (n *TmpNode) RunClient() { | ||||
| 	defer func() { | ||||
| 		if r := recover(); r != nil { | ||||
| 			fmt.Printf("%v", r) | ||||
| 			fmt.Printf("%v\n", r) | ||||
| 			debug.PrintStack() | ||||
| 		} | ||||
| 	}() | ||||
| @@ -184,6 +176,10 @@ func (node *TmpNode) readFromConn() { | ||||
| 		// We assume that we're only receiving packets from one source. | ||||
|  | ||||
| 		_, err = node.iface.Write(packet) | ||||
| 		must(err) | ||||
| 		if err != nil { | ||||
| 			log.Printf("Got error: %v", err) | ||||
| 		} | ||||
| 		//must(err) | ||||
| 	} | ||||
| } | ||||
| */ | ||||
|   | ||||
							
								
								
									
										312
									
								
								node/tmp_peerstate.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										312
									
								
								node/tmp_peerstate.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,312 @@ | ||||
| package node | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"net/netip" | ||||
| 	"time" | ||||
| 	"vppn/m" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	connectTimeout  = 6 * time.Second | ||||
| 	pingInterval    = 6 * time.Second | ||||
| 	timeoutInterval = 20 * time.Second | ||||
| ) | ||||
|  | ||||
| type routingPacketWrapper struct { | ||||
| 	routingPacket | ||||
| 	Addr netip.AddrPort // Source. | ||||
| } | ||||
|  | ||||
| type peerSupervisor struct { | ||||
| 	// Constants: | ||||
| 	localIP     byte | ||||
| 	localPublic bool | ||||
| 	remoteIP    byte | ||||
| 	privKey     []byte | ||||
|  | ||||
| 	// Shared data: | ||||
| 	w     *connWriter | ||||
| 	table *routingTable | ||||
|  | ||||
| 	packets     chan routingPacketWrapper | ||||
| 	peerUpdates chan *m.Peer | ||||
|  | ||||
| 	// Peer-related items. | ||||
| 	version        int64 // Ony accessed in HandlePeerUpdate. | ||||
| 	peer           *m.Peer | ||||
| 	remoteAddrPort *netip.AddrPort | ||||
| 	sharedKey      []byte | ||||
|  | ||||
| 	// Used by our state functions. | ||||
| 	pingTimer    *time.Timer | ||||
| 	timeoutTimer *time.Timer | ||||
| 	buf          []byte | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func newPeerSupervisor( | ||||
| 	conf m.PeerConfig, | ||||
| 	remoteIP byte, | ||||
| 	w *connWriter, | ||||
| 	table *routingTable, | ||||
| ) *peerSupervisor { | ||||
| 	s := &peerSupervisor{ | ||||
| 		localIP:      conf.PeerIP, | ||||
| 		remoteIP:     remoteIP, | ||||
| 		privKey:      conf.EncPrivKey, | ||||
| 		w:            w, | ||||
| 		table:        table, | ||||
| 		packets:      make(chan routingPacketWrapper, 256), | ||||
| 		peerUpdates:  make(chan *m.Peer, 1), | ||||
| 		pingTimer:    time.NewTimer(pingInterval), | ||||
| 		timeoutTimer: time.NewTimer(timeoutInterval), | ||||
| 		buf:          make([]byte, bufferSize), | ||||
| 	} | ||||
|  | ||||
| 	_, s.localPublic = netip.AddrFromSlice(conf.PublicIP) | ||||
|  | ||||
| 	go s.mainLoop() | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func (s *peerSupervisor) logf(msg string, args ...any) { | ||||
| 	msg = fmt.Sprintf("[%03d] ", s.remoteIP) + msg | ||||
| 	log.Printf(msg, args...) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) mainLoop() { | ||||
| 	defer panicHandler() | ||||
| 	state := s.stateInit | ||||
| 	for { | ||||
| 		state = state() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) HandlePeerUpdate(p *m.Peer) { | ||||
| 	if p != nil { | ||||
| 		if p.Version == s.version { | ||||
| 			return | ||||
| 		} | ||||
| 		s.logf("New peer version: %d", p.Version) | ||||
| 		s.version = p.Version | ||||
| 	} else { | ||||
| 		s.version = 0 | ||||
| 	} | ||||
|  | ||||
| 	s.peerUpdates <- p | ||||
| } | ||||
|  | ||||
| func (s *peerSupervisor) HandlePacket(w routingPacketWrapper) { | ||||
| 	select { | ||||
| 	case s.packets <- w: | ||||
| 	default: | ||||
| 		// Drop | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type stateFunc func() stateFunc | ||||
|  | ||||
| func (s *peerSupervisor) stateInit() stateFunc { | ||||
| 	if s.peer == nil { | ||||
| 		return s.stateDisconnected | ||||
| 	} | ||||
|  | ||||
| 	addr, ok := netip.AddrFromSlice(s.peer.PublicIP) | ||||
| 	if ok { | ||||
| 		addrPort := netip.AddrPortFrom(addr, s.peer.Port) | ||||
| 		s.remoteAddrPort = &addrPort | ||||
| 	} else { | ||||
| 		s.remoteAddrPort = nil | ||||
| 	} | ||||
| 	s.sharedKey = computeSharedKey(s.peer.EncPubKey, s.privKey) | ||||
|  | ||||
| 	return s.stateSelectRole() | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) stateDisconnected() stateFunc { | ||||
| 	s.clearRoutingTable() | ||||
|  | ||||
| 	for { | ||||
| 		select { | ||||
| 		case <-s.packets: | ||||
| 			// Drop | ||||
| 		case s.peer = <-s.peerUpdates: | ||||
| 			return s.stateInit | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) stateSelectRole() stateFunc { | ||||
| 	s.logf("STATE: SelectRole") | ||||
| 	s.updateRoutingTable(false) | ||||
|  | ||||
| 	if s.remoteAddrPort != nil { | ||||
| 		// If both remote and local are public, one side acts as client, and one | ||||
| 		// side as server. | ||||
| 		if s.localPublic && s.localIP < s.peer.PeerIP { | ||||
| 			return s.stateAccept | ||||
| 		} | ||||
| 		return s.stateDial | ||||
| 	} | ||||
|  | ||||
| 	// We're public, remote is not => can only wait for connection | ||||
| 	if s.localPublic { | ||||
| 		return s.stateAccept | ||||
| 	} | ||||
|  | ||||
| 	// Both non-public => need to use mediator. | ||||
| 	return s.stateMediated | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) stateAccept() stateFunc { | ||||
| 	s.logf("STATE: Accept") | ||||
|  | ||||
| 	for { | ||||
|  | ||||
| 		select { | ||||
| 		case pkt := <-s.packets: | ||||
| 			switch pkt.Type { | ||||
|  | ||||
| 			case packetTypePing: | ||||
| 				s.remoteAddrPort = &pkt.Addr | ||||
| 				s.updateRoutingTable(true) | ||||
| 				s.sendPong(pkt.TraceID) | ||||
| 				return s.stateConnected | ||||
|  | ||||
| 			default: | ||||
| 				// Still waiting for ping... | ||||
| 			} | ||||
|  | ||||
| 		case s.peer = <-s.peerUpdates: | ||||
| 			return s.stateInit | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) stateDial() stateFunc { | ||||
| 	s.logf("STATE: Dial") | ||||
| 	s.updateRoutingTable(false) | ||||
|  | ||||
| 	s.sendPing() | ||||
|  | ||||
| 	for { | ||||
| 		select { | ||||
| 		case pkt := <-s.packets: | ||||
|  | ||||
| 			switch pkt.Type { | ||||
|  | ||||
| 			case packetTypePong: | ||||
| 				s.updateRoutingTable(true) | ||||
| 				return s.stateConnected | ||||
|  | ||||
| 			default: | ||||
| 				// Ignore | ||||
| 			} | ||||
|  | ||||
| 		case <-s.pingTimer.C: | ||||
| 			s.sendPing() | ||||
|  | ||||
| 		case s.peer = <-s.peerUpdates: | ||||
| 			return s.stateInit | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) stateConnected() stateFunc { | ||||
| 	s.logf("STATE: Connected") | ||||
|  | ||||
| 	s.timeoutTimer.Reset(timeoutInterval) | ||||
|  | ||||
| 	for { | ||||
| 		select { | ||||
|  | ||||
| 		case <-s.pingTimer.C: | ||||
| 			s.sendPing() | ||||
|  | ||||
| 		case <-s.timeoutTimer.C: | ||||
| 			s.logf("Timeout") | ||||
| 			return s.stateInit | ||||
|  | ||||
| 		case pkt := <-s.packets: | ||||
| 			switch pkt.Type { | ||||
| 			case packetTypePing: | ||||
| 				s.sendPong(pkt.TraceID) | ||||
|  | ||||
| 				// Server should always follow remote port. | ||||
| 				if s.localPublic { | ||||
| 					if pkt.Addr != *s.remoteAddrPort { | ||||
| 						s.remoteAddrPort = &pkt.Addr | ||||
| 						s.updateRoutingTable(true) | ||||
| 					} | ||||
| 				} | ||||
|  | ||||
| 			case packetTypePong: | ||||
| 				s.timeoutTimer.Reset(timeoutInterval) | ||||
|  | ||||
| 			default: | ||||
| 				// Drop packet. | ||||
| 			} | ||||
|  | ||||
| 		case s.peer = <-s.peerUpdates: | ||||
| 			s.logf("New peer: %v", s.peer) | ||||
| 			return s.stateInit | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) stateMediated() stateFunc { | ||||
| 	s.logf("STATE: Mediated") | ||||
| 	// TODO | ||||
| 	select {} | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) clearRoutingTable() { | ||||
| 	s.table.Set(s.remoteIP, nil) | ||||
| } | ||||
|  | ||||
| func (s *peerSupervisor) updateRoutingTable(up bool) { | ||||
| 	s.table.Set(s.remoteIP, &peer{ | ||||
| 		Up:        up, | ||||
| 		IP:        s.remoteIP, | ||||
| 		Addr:      s.remoteAddrPort, | ||||
| 		SharedKey: s.sharedKey, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (s *peerSupervisor) sendPing() uint64 { | ||||
| 	traceID := newTraceID() | ||||
| 	pkt := newRoutingPacket(packetTypePing, traceID) | ||||
| 	s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf)) | ||||
| 	s.pingTimer.Reset(pingInterval) | ||||
| 	return traceID | ||||
| } | ||||
|  | ||||
| func (s *peerSupervisor) sendPong(traceID uint64) { | ||||
| 	pkt := newRoutingPacket(packetTypePong, traceID) | ||||
| 	s.w.WriteTo(s.peer.PeerIP, streamRouting, pkt.Marshal(s.buf)) | ||||
| } | ||||
		Reference in New Issue
	
	Block a user