package peer import ( "bytes" "crypto/rand" "encoding/base64" "encoding/json" "fmt" "net/http" "net/netip" "os" "time" "golang.org/x/crypto/nacl/sign" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "vppn/m" ) // LocalState is the persisted identity for this peer, written on first run and // loaded on every subsequent run. type LocalState struct { PrivKey wgtypes.Key SignKey [64]byte // nacl/sign Ed25519 private key VPNIP netip.Addr VPNNet netip.Prefix WGPort uint16 IsRelay bool IsPublic bool LocalDomain string } // localStateJSON is the on-disk representation. type localStateJSON struct { PrivKey string SignKey string VPNIP netip.Addr VPNNet netip.Prefix WGPort uint16 IsRelay bool IsPublic bool LocalDomain string } // LoadOrInit loads LocalState from path, or registers with the hub and creates // the file if it doesn't exist. func LoadOrInit(statePath, hubURL, apiKey string) (LocalState, error) { var state LocalState switch err := loadJSON(statePath, &state); { case err == nil: return state, nil case !os.IsNotExist(err): // File exists but is unreadable/corrupt: surface it rather than // silently regenerating a new identity and re-registering. return LocalState{}, fmt.Errorf("load state: %w", err) } privKey, err := wgtypes.GeneratePrivateKey() if err != nil { return LocalState{}, fmt.Errorf("generate key: %w", err) } state, err = initFromHub(hubURL, apiKey, privKey) if err != nil { return LocalState{}, err } if err := storeJSON(statePath, state); err != nil { return LocalState{}, fmt.Errorf("save state: %w", err) } return state, nil } func initFromHub(hubURL, apiKey string, privKey wgtypes.Key) (LocalState, error) { wgPubKey := privKey.PublicKey() signPubKey, signPrivKey, err := sign.GenerateKey(rand.Reader) if err != nil { return LocalState{}, fmt.Errorf("generate sign key: %w", err) } body, err := json.Marshal(m.PeerInitArgs{ WGPubKey: wgPubKey[:], SignPubKey: signPubKey[:], }) if err != nil { return LocalState{}, fmt.Errorf("json error: %w", err) } req, err := http.NewRequest(http.MethodPost, hubURL+"/peer/init/", bytes.NewReader(body)) if err != nil { return LocalState{}, err } req.SetBasicAuth("", apiKey) req.Header.Set("Content-Type", "application/json") resp, err := (&http.Client{Timeout: time.Minute}).Do(req) if err != nil { return LocalState{}, fmt.Errorf("hub init: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return LocalState{}, fmt.Errorf("hub init: HTTP %d", resp.StatusCode) } var r m.PeerInitResp if err := json.NewDecoder(resp.Body).Decode(&r); err != nil { return LocalState{}, fmt.Errorf("hub init decode: %w", err) } if len(r.Network) != 4 { return LocalState{}, fmt.Errorf("hub init: invalid network %v", r.Network) } netAddr := netip.AddrFrom4([4]byte(r.Network)) octets := netAddr.As4() octets[3] = r.PeerIP vpnIP := netip.AddrFrom4(octets) vpnNet := netip.PrefixFrom(netAddr, 24) var self *m.Peer for i := range r.NetworkState.Peers { if r.NetworkState.Peers[i].PeerIP == r.PeerIP { self = &r.NetworkState.Peers[i] break } } if self == nil { return LocalState{}, fmt.Errorf("hub init: no peer for own IP: %d", r.PeerIP) } public := self.IsPublic() return LocalState{ PrivKey: privKey, SignKey: *signPrivKey, VPNIP: vpnIP, VPNNet: vpnNet, WGPort: self.Port, IsRelay: self.Relay && public, IsPublic: public, LocalDomain: r.LocalDomain, }, nil } func (s LocalState) MarshalJSON() ([]byte, error) { return json.Marshal(localStateJSON{ PrivKey: base64.StdEncoding.EncodeToString(s.PrivKey[:]), SignKey: base64.StdEncoding.EncodeToString(s.SignKey[:]), VPNIP: s.VPNIP, VPNNet: s.VPNNet, WGPort: s.WGPort, IsRelay: s.IsRelay, IsPublic: s.IsPublic, LocalDomain: s.LocalDomain, }) } func (s *LocalState) UnmarshalJSON(data []byte) error { var j localStateJSON if err := json.Unmarshal(data, &j); err != nil { return err } keyBytes, err := base64.StdEncoding.DecodeString(j.PrivKey) if err != nil { return fmt.Errorf("decode key: %w", err) } key, err := wgtypes.NewKey(keyBytes) if err != nil { return fmt.Errorf("invalid key: %w", err) } signKeyBytes, err := base64.StdEncoding.DecodeString(j.SignKey) if err != nil { return fmt.Errorf("decode sign key: %w", err) } if len(signKeyBytes) != 64 { return fmt.Errorf("invalid sign key length: %d", len(signKeyBytes)) } *s = LocalState{ PrivKey: key, SignKey: [64]byte(signKeyBytes), VPNIP: j.VPNIP, VPNNet: j.VPNNet, WGPort: j.WGPort, IsRelay: j.IsRelay, IsPublic: j.IsPublic, LocalDomain: j.LocalDomain, } return nil }