sym-encryption #1
							
								
								
									
										13
									
								
								node/main.go
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								node/main.go
									
									
									
									
									
								
							| @@ -108,11 +108,11 @@ func main(netName, listenIP string, port uint16) { | |||||||
| 	peers := remotePeers{} | 	peers := remotePeers{} | ||||||
|  |  | ||||||
| 	for i := range peers { | 	for i := range peers { | ||||||
| 		peers[i] = newRemotePeer(conf, byte(i), ifWriter, connWriter) | 		peers[i] = newRemotePeer(conf, byte(i), ifWriter, connWriter, &peers) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	go newHubPoller(netName, conf, peers).Run() | 	go newHubPoller(netName, conf, peers).Run() | ||||||
| 	go readFromConn(conn, peers) | 	go readFromConn(conf.PeerIP, conn, peers) | ||||||
| 	readFromIFace(iface, peers) | 	readFromIFace(iface, peers) | ||||||
|  |  | ||||||
| } | } | ||||||
| @@ -131,7 +131,7 @@ func determinePort(confPort, portFromCommandLine uint16) uint16 { | |||||||
|  |  | ||||||
| // ---------------------------------------------------------------------------- | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| func readFromConn(conn *net.UDPConn, peers remotePeers) { | func readFromConn(localIP byte, conn *net.UDPConn, peers remotePeers) { | ||||||
|  |  | ||||||
| 	defer panicHandler() | 	defer panicHandler() | ||||||
|  |  | ||||||
| @@ -157,7 +157,12 @@ func readFromConn(conn *net.UDPConn, peers remotePeers) { | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		h.Parse(data) | 		h.Parse(data) | ||||||
| 		peers[h.SourceIP].HandlePacket(remoteAddr, h, data) |  | ||||||
|  | 		if h.DestIP == localIP { | ||||||
|  | 			peers[h.SourceIP].HandlePacket(remoteAddr, h, data) | ||||||
|  | 		} else { | ||||||
|  | 			peers[h.DestIP].ForwardPacket(data) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,6 +1,8 @@ | |||||||
| package node | package node | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"log" | ||||||
|  | 	"math/rand" | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
| 	"sync/atomic" | 	"sync/atomic" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -47,12 +49,15 @@ func (rp *peerSuper) Run() { | |||||||
|  |  | ||||||
| func (rp *peerSuper) stateInit() stateFunc { | func (rp *peerSuper) stateInit() stateFunc { | ||||||
| 	//rp.logf("STATE: Init") | 	//rp.logf("STATE: Init") | ||||||
|  |  | ||||||
| 	x := peerData{} | 	x := peerData{} | ||||||
| 	rp.shared.Store(&x) | 	rp.shared.Store(&x) | ||||||
|  |  | ||||||
|  | 	rp.peerData.relay = false | ||||||
| 	rp.peerData.controlCipher = nil | 	rp.peerData.controlCipher = nil | ||||||
| 	rp.peerData.dataCipher = nil | 	rp.peerData.dataCipher = nil | ||||||
| 	rp.peerData.remoteAddr = zeroAddrPort | 	rp.peerData.remoteAddr = zeroAddrPort | ||||||
|  | 	rp.peerData.relayIP = 0 | ||||||
|  |  | ||||||
| 	if rp.peer == nil { | 	if rp.peer == nil { | ||||||
| 		return rp.stateDisconnected | 		return rp.stateDisconnected | ||||||
| @@ -62,6 +67,8 @@ func (rp *peerSuper) stateInit() stateFunc { | |||||||
| 	addr, rp.remotePublic = netip.AddrFromSlice(rp.peer.PublicIP) | 	addr, rp.remotePublic = netip.AddrFromSlice(rp.peer.PublicIP) | ||||||
| 	if rp.remotePublic { | 	if rp.remotePublic { | ||||||
| 		rp.peerData.remoteAddr = netip.AddrPortFrom(addr, rp.peer.Port) | 		rp.peerData.remoteAddr = netip.AddrPortFrom(addr, rp.peer.Port) | ||||||
|  | 	} else { | ||||||
|  | 		rp.peerData.relay = false | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	rp.peerData.controlCipher = newControlCipher(rp.privKey, rp.peer.EncPubKey) | 	rp.peerData.controlCipher = newControlCipher(rp.privKey, rp.peer.EncPubKey) | ||||||
| @@ -89,8 +96,7 @@ func (rp *peerSuper) stateSelectRole() stateFunc { | |||||||
| 	rp.logf("STATE: SelectRole") | 	rp.logf("STATE: SelectRole") | ||||||
|  |  | ||||||
| 	if !rp.localPublic && !rp.remotePublic { | 	if !rp.localPublic && !rp.remotePublic { | ||||||
| 		// TODO! | 		return rp.stateSelectMediator | ||||||
| 		return rp.stateDisconnected |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if !rp.localPublic { | 	if !rp.localPublic { | ||||||
| @@ -99,12 +105,55 @@ func (rp *peerSuper) stateSelectRole() stateFunc { | |||||||
| 		return rp.stateClient | 		return rp.stateClient | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if rp.localIP < rp.peer.PeerIP { | 	if rp.localIP < rp.remoteIP { | ||||||
| 		return rp.stateClient | 		return rp.stateClient | ||||||
| 	} | 	} | ||||||
| 	return rp.stateServer | 	return rp.stateServer | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | func (rp *peerSuper) stateSelectMediator() stateFunc { | ||||||
|  | 	rp.logf("STATE: SelectMediator") | ||||||
|  |  | ||||||
|  | 	for { | ||||||
|  | 		log.Printf("Selecting mediator...") | ||||||
|  | 		if ip := rp.selectMediator(); ip != 0 { | ||||||
|  | 			rp.logf("Got mediator: %d", ip) | ||||||
|  | 			rp.peerData.relayIP = ip | ||||||
|  |  | ||||||
|  | 			if rp.localIP < rp.remoteIP { | ||||||
|  | 				return rp.stateClient | ||||||
|  | 			} | ||||||
|  | 			return rp.stateServer | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		select { | ||||||
|  | 		case <-time.After(pingInterval): | ||||||
|  | 			continue | ||||||
|  | 		case rp.peer = <-rp.peerUpdates: | ||||||
|  | 			return rp.stateInit | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (rp *peerSuper) selectMediator() byte { | ||||||
|  | 	possible := make([]byte, 0, 8) | ||||||
|  | 	for _, peer := range rp.peers { | ||||||
|  | 		if peer.canRelay() { | ||||||
|  | 			rp.logf("relay: %v", peer.shared.Load()) | ||||||
|  | 			possible = append(possible, peer.remoteIP) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if len(possible) == 0 { | ||||||
|  | 		return 0 | ||||||
|  | 	} | ||||||
|  | 	return possible[rand.Intn(len(possible))] | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| // The remote is a server. | // The remote is a server. | ||||||
| func (rp *peerSuper) stateServer() stateFunc { | func (rp *peerSuper) stateServer() stateFunc { | ||||||
| 	rp.logf("STATE: Server") | 	rp.logf("STATE: Server") | ||||||
| @@ -112,10 +161,12 @@ func (rp *peerSuper) stateServer() stateFunc { | |||||||
| 	rp.updateShared() | 	rp.updateShared() | ||||||
|  |  | ||||||
| 	var ( | 	var ( | ||||||
| 		pingTimer = time.NewTimer(pingInterval) | 		pingTimer    = time.NewTimer(pingInterval) | ||||||
| 		ping      = pingPacket{SharedKey: ([32]byte)(rp.peerData.dataCipher.Key())} | 		timeoutTimer = time.NewTimer(timeoutInterval) | ||||||
|  | 		ping         = pingPacket{SharedKey: ([32]byte)(rp.peerData.dataCipher.Key())} | ||||||
| 	) | 	) | ||||||
| 	defer pingTimer.Stop() | 	defer pingTimer.Stop() | ||||||
|  | 	defer timeoutTimer.Stop() | ||||||
|  |  | ||||||
| 	ping.SentAt = time.Now().UnixMilli() | 	ping.SentAt = time.Now().UnixMilli() | ||||||
| 	rp.sendControlPacket(ping) | 	rp.sendControlPacket(ping) | ||||||
| @@ -127,8 +178,18 @@ func (rp *peerSuper) stateServer() stateFunc { | |||||||
| 			rp.sendControlPacket(ping) | 			rp.sendControlPacket(ping) | ||||||
| 			pingTimer.Reset(pingInterval) | 			pingTimer.Reset(pingInterval) | ||||||
|  |  | ||||||
| 		case <-rp.controlPackets: | 		case cPkt := <-rp.controlPackets: | ||||||
| 			// Ignore | 			if _, ok := cPkt.Payload.(pongPacket); ok { | ||||||
|  | 				timeoutTimer.Reset(timeoutInterval) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 		case <-timeoutTimer.C: | ||||||
|  | 			if rp.peerData.relayIP != 0 { | ||||||
|  | 				rp.logf("Timeout (server, relay)") | ||||||
|  | 				return rp.stateSelectMediator | ||||||
|  | 			} else { | ||||||
|  | 				rp.logf("Timeout (server)") | ||||||
|  | 			} | ||||||
|  |  | ||||||
| 		case rp.peer = <-rp.peerUpdates: | 		case rp.peer = <-rp.peerUpdates: | ||||||
| 			return rp.stateInit | 			return rp.stateInit | ||||||
| @@ -143,8 +204,12 @@ func (rp *peerSuper) stateClient() stateFunc { | |||||||
| 	rp.logf("STATE: Client") | 	rp.logf("STATE: Client") | ||||||
| 	rp.updateShared() | 	rp.updateShared() | ||||||
|  |  | ||||||
| 	// TODO: Could use timeout to set dataCipher to nil. | 	var ( | ||||||
| 	var currentKey = [32]byte{} | 		currentKey   = [32]byte{} | ||||||
|  | 		timeoutTimer = time.NewTimer(timeoutInterval) | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	defer timeoutTimer.Stop() | ||||||
|  |  | ||||||
| 	for { | 	for { | ||||||
| 		select { | 		select { | ||||||
| @@ -163,12 +228,22 @@ func (rp *peerSuper) stateClient() stateFunc { | |||||||
| 			if ping.SharedKey != currentKey { | 			if ping.SharedKey != currentKey { | ||||||
| 				rp.logf("Connected with new shared key") | 				rp.logf("Connected with new shared key") | ||||||
| 				currentKey = ping.SharedKey | 				currentKey = ping.SharedKey | ||||||
|  | 				rp.peerData.up = true | ||||||
| 				rp.peerData.dataCipher = newDataCipherFromKey(currentKey) | 				rp.peerData.dataCipher = newDataCipherFromKey(currentKey) | ||||||
| 				rp.updateShared() | 				rp.updateShared() | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | 			timeoutTimer.Reset(timeoutInterval) | ||||||
| 			rp.sendControlPacket(newPongPacket(ping.SentAt)) | 			rp.sendControlPacket(newPongPacket(ping.SentAt)) | ||||||
|  |  | ||||||
|  | 		case <-timeoutTimer.C: | ||||||
|  | 			if rp.peerData.relayIP != 0 { | ||||||
|  | 				rp.logf("Timeout (server, relay)") | ||||||
|  | 				return rp.stateSelectMediator | ||||||
|  | 			} else { | ||||||
|  | 				rp.logf("Timeout (server)") | ||||||
|  | 			} | ||||||
|  |  | ||||||
| 		case rp.peer = <-rp.peerUpdates: | 		case rp.peer = <-rp.peerUpdates: | ||||||
| 			return rp.stateInit | 			return rp.stateInit | ||||||
| 		} | 		} | ||||||
| @@ -193,5 +268,10 @@ func (rp *peerSuper) sendControlPacket(pkt interface{ Marshal([]byte) []byte }) | |||||||
| 		DestIP:   rp.remoteIP, | 		DestIP:   rp.remoteIP, | ||||||
| 	} | 	} | ||||||
| 	buf = rp.peerData.controlCipher.Encrypt(h, buf, rp.encBuf) | 	buf = rp.peerData.controlCipher.Encrypt(h, buf, rp.encBuf) | ||||||
| 	rp.conn.WriteTo(buf, rp.peerData.remoteAddr) | 	if rp.peerData.relayIP == 0 { | ||||||
|  | 		rp.conn.WriteTo(buf, rp.peerData.remoteAddr) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	rp.peers[rp.peerData.relayIP].RelayControlData(buf) | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										56
									
								
								node/peer.go
									
									
									
									
									
								
							
							
						
						
									
										56
									
								
								node/peer.go
									
									
									
									
									
								
							| @@ -12,6 +12,8 @@ import ( | |||||||
| type remotePeers [256]*remotePeer | type remotePeers [256]*remotePeer | ||||||
|  |  | ||||||
| type peerData struct { | type peerData struct { | ||||||
|  | 	up            bool | ||||||
|  | 	relay         bool | ||||||
| 	controlCipher *controlCipher | 	controlCipher *controlCipher | ||||||
| 	dataCipher    *dataCipher | 	dataCipher    *dataCipher | ||||||
| 	remoteAddr    netip.AddrPort | 	remoteAddr    netip.AddrPort | ||||||
| @@ -28,6 +30,7 @@ type remotePeer struct { | |||||||
| 	conn        *connWriter | 	conn        *connWriter | ||||||
|  |  | ||||||
| 	// Shared state. | 	// Shared state. | ||||||
|  | 	peers  *remotePeers | ||||||
| 	shared *atomic.Pointer[peerData] | 	shared *atomic.Pointer[peerData] | ||||||
|  |  | ||||||
| 	// Only used in HandlePeerUpdate. | 	// Only used in HandlePeerUpdate. | ||||||
| @@ -48,7 +51,7 @@ type remotePeer struct { | |||||||
| 	controlPackets chan controlPacket | 	controlPackets chan controlPacket | ||||||
| } | } | ||||||
|  |  | ||||||
| func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *connWriter) *remotePeer { | func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *connWriter, peers *remotePeers) *remotePeer { | ||||||
| 	rp := &remotePeer{ | 	rp := &remotePeer{ | ||||||
| 		localIP:        conf.PeerIP, | 		localIP:        conf.PeerIP, | ||||||
| 		remoteIP:       remoteIP, | 		remoteIP:       remoteIP, | ||||||
| @@ -56,6 +59,7 @@ func newRemotePeer(conf m.PeerConfig, remoteIP byte, iface *ifWriter, conn *conn | |||||||
| 		localPublic:    addrIsValid(conf.PublicIP), | 		localPublic:    addrIsValid(conf.PublicIP), | ||||||
| 		iface:          iface, | 		iface:          iface, | ||||||
| 		conn:           conn, | 		conn:           conn, | ||||||
|  | 		peers:          peers, | ||||||
| 		shared:         &atomic.Pointer[peerData]{}, | 		shared:         &atomic.Pointer[peerData]{}, | ||||||
| 		dupCheck:       newDupCheck(0), | 		dupCheck:       newDupCheck(0), | ||||||
| 		decryptBuf:     make([]byte, bufferSize), | 		decryptBuf:     make([]byte, bufferSize), | ||||||
| @@ -97,10 +101,6 @@ func (rp *remotePeer) HandlePacket(addr netip.AddrPort, h header, data []byte) { | |||||||
| 	case dataStreamID: | 	case dataStreamID: | ||||||
| 		rp.handleDataPacket(data) | 		rp.handleDataPacket(data) | ||||||
|  |  | ||||||
| 	case forwardStreamID: |  | ||||||
| 		fallthrough |  | ||||||
| 		// TODO |  | ||||||
| 		//rp.handleForwardPacket(h, data) |  | ||||||
| 	default: | 	default: | ||||||
| 		rp.logf("Unknown stream ID: %d", h.StreamID) | 		rp.logf("Unknown stream ID: %d", h.StreamID) | ||||||
| 	} | 	} | ||||||
| @@ -115,6 +115,11 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data [] | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if h.DestIP != rp.localIP { | ||||||
|  | 		rp.logf("Incorrect destination IP on control packet.") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	out, ok := shared.controlCipher.Decrypt(data, rp.decryptBuf) | 	out, ok := shared.controlCipher.Decrypt(data, rp.decryptBuf) | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		rp.logf("Failed to decrypt control packet.") | 		rp.logf("Failed to decrypt control packet.") | ||||||
| @@ -131,13 +136,6 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data [] | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if h.DestIP != rp.localIP { |  | ||||||
| 		// TODO: Forward control packet. |  | ||||||
| 		// TODO: Probably this should be dropped. |  | ||||||
| 		// Control packets should be forwarded as data for efficiency. |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	pkt := controlPacket{ | 	pkt := controlPacket{ | ||||||
| 		SrcIP:      h.SourceIP, | 		SrcIP:      h.SourceIP, | ||||||
| 		RemoteAddr: addr, | 		RemoteAddr: addr, | ||||||
| @@ -167,6 +165,8 @@ func (rp *remotePeer) handleControlPacket(addr netip.AddrPort, h header, data [] | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
| func (rp *remotePeer) handleDataPacket(data []byte) { | func (rp *remotePeer) handleDataPacket(data []byte) { | ||||||
| 	shared := rp.shared.Load() | 	shared := rp.shared.Load() | ||||||
| 	if shared.dataCipher == nil { | 	if shared.dataCipher == nil { | ||||||
| @@ -189,6 +189,29 @@ func (rp *remotePeer) handleDataPacket(data []byte) { | |||||||
| // | // | ||||||
| // This function is called by a single thread. | // This function is called by a single thread. | ||||||
| func (rp *remotePeer) SendData(data []byte) { | func (rp *remotePeer) SendData(data []byte) { | ||||||
|  | 	rp.sendData(dataStreamID, data) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | func (rp *remotePeer) RelayControlData(data []byte) { | ||||||
|  | 	rp.sendData(forwardStreamID, data) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | func (rp *remotePeer) ForwardPacket(data []byte) { | ||||||
|  | 	shared := rp.shared.Load() | ||||||
|  | 	if shared.remoteAddr == zeroAddrPort { | ||||||
|  | 		rp.logf("Not connected (forward).") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	rp.conn.WriteTo(data, shared.remoteAddr) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | func (rp *remotePeer) sendData(streamID byte, data []byte) { | ||||||
| 	shared := rp.shared.Load() | 	shared := rp.shared.Load() | ||||||
| 	if shared.dataCipher == nil || shared.remoteAddr == zeroAddrPort { | 	if shared.dataCipher == nil || shared.remoteAddr == zeroAddrPort { | ||||||
| 		rp.logf("Not connected (send).") | 		rp.logf("Not connected (send).") | ||||||
| @@ -196,7 +219,7 @@ func (rp *remotePeer) SendData(data []byte) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	h := header{ | 	h := header{ | ||||||
| 		StreamID: dataStreamID, | 		StreamID: streamID, | ||||||
| 		Counter:  atomic.AddUint64(&rp.counter, 1), | 		Counter:  atomic.AddUint64(&rp.counter, 1), | ||||||
| 		SourceIP: rp.localIP, | 		SourceIP: rp.localIP, | ||||||
| 		DestIP:   rp.remoteIP, | 		DestIP:   rp.remoteIP, | ||||||
| @@ -205,3 +228,10 @@ func (rp *remotePeer) SendData(data []byte) { | |||||||
| 	enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf) | 	enc := shared.dataCipher.Encrypt(h, data, rp.encryptBuf) | ||||||
| 	rp.conn.WriteTo(enc, shared.remoteAddr) | 	rp.conn.WriteTo(enc, shared.remoteAddr) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | func (rp *remotePeer) canRelay() bool { | ||||||
|  | 	shared := rp.shared.Load() | ||||||
|  | 	return shared.relay && shared.up | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user