Files
vppn/peer/init.go
2026-06-14 06:17:35 +02:00

192 lines
4.7 KiB
Go

package peer
import (
	"bytes"
	"crypto/rand"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/netip"
	"os"
	"time"

	"golang.org/x/crypto/nacl/sign"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

	"vppn/m"
)
// LocalState is the persisted identity for this peer, written on first run and
// loaded on every subsequent run.
//
// It is produced by initFromHub from a freshly generated key pair plus the
// hub's registration response, and it is (de)serialized through its
// MarshalJSON/UnmarshalJSON methods, which go via localStateJSON.
type LocalState struct {
// PrivKey is the WireGuard private key generated in LoadOrInit; its public
// half is what gets sent to the hub during registration.
PrivKey wgtypes.Key
SignKey [64]byte // nacl/sign Ed25519 private key
// VPNIP is the hub's 4-byte network address with the last octet replaced by
// the PeerIP the hub assigned to this peer.
VPNIP netip.Addr
// VPNNet is the hub's network address as a /24 prefix.
VPNNet netip.Prefix
// WGPort is the Port value of this peer's own entry in the hub's network
// state (presumably the WireGuard listen port — confirm against m.Peer).
WGPort uint16
// IsRelay is true only when the hub marks this peer as a relay AND the peer
// is public.
IsRelay bool
// IsPublic is the result of IsPublic() on this peer's own entry in the
// hub's network state.
IsPublic bool
// LocalDomain is copied verbatim from the hub's registration response.
LocalDomain string
}
// localStateJSON is the on-disk representation of LocalState.
//
// It mirrors LocalState field for field, except that the two keys are plain
// strings: LocalState.MarshalJSON fills them with the standard-base64 encoding
// of the raw key bytes and LocalState.UnmarshalJSON decodes and validates them.
// No struct tags are set, so the JSON object keys are the Go field names.
type localStateJSON struct {
// PrivKey is base64 (StdEncoding) of LocalState.PrivKey.
PrivKey string
// SignKey is base64 (StdEncoding) of LocalState.SignKey; must decode to
// exactly 64 bytes to be accepted on load.
SignKey string
VPNIP netip.Addr
VPNNet netip.Prefix
WGPort uint16
IsRelay bool
IsPublic bool
LocalDomain string
}
// LoadOrInit loads LocalState from statePath, or registers with the hub and
// creates the file if it doesn't exist.
//
// Behavior:
//   - If the state file loads cleanly, it is returned as-is and the hub is not
//     contacted.
//   - If loading fails for any reason other than "file does not exist", the
//     error is returned wrapped as "load state: ...". An unreadable or corrupt
//     file is never silently replaced by a new identity.
//   - If the file does not exist, a new WireGuard private key is generated, the
//     peer is registered at hubURL (authenticated with apiKey), and the
//     resulting state is written to statePath before being returned.
//
// On any error the returned LocalState is the zero value.
func LoadOrInit(statePath, hubURL, apiKey string) (LocalState, error) {
	var state LocalState
	err := loadJSON(statePath, &state)
	if err == nil {
		return state, nil
	}
	// Use errors.Is instead of os.IsNotExist: the latter does not unwrap, so a
	// not-exist error that loadJSON wrapped with extra context would be
	// mistaken for a corrupt file and first-run initialization would never
	// happen.
	if !errors.Is(err, os.ErrNotExist) {
		// File exists but is unreadable/corrupt: surface it rather than
		// silently regenerating a new identity and re-registering.
		return LocalState{}, fmt.Errorf("load state: %w", err)
	}

	privKey, err := wgtypes.GeneratePrivateKey()
	if err != nil {
		return LocalState{}, fmt.Errorf("generate key: %w", err)
	}

	state, err = initFromHub(hubURL, apiKey, privKey)
	if err != nil {
		return LocalState{}, err
	}

	if err := storeJSON(statePath, state); err != nil {
		return LocalState{}, fmt.Errorf("save state: %w", err)
	}
	return state, nil
}
// initFromHub registers this peer with the hub and assembles its LocalState.
//
// It generates a fresh nacl/sign key pair, POSTs the WireGuard and signing
// public keys as JSON to hubURL+"/peer/init/" (basic auth with an empty user
// name and apiKey as the password, one-minute client timeout), and decodes the
// reply. The reply must carry a 4-byte network address and must list a peer
// whose PeerIP equals the PeerIP assigned to us; otherwise an error is
// returned. The VPN address is the network address with its last octet set to
// the assigned PeerIP, and the VPN network is that network address as a /24.
//
// On any error the returned LocalState is the zero value.
func initFromHub(hubURL, apiKey string, privKey wgtypes.Key) (LocalState, error) {
	var none LocalState

	signPub, signPriv, err := sign.GenerateKey(rand.Reader)
	if err != nil {
		return none, fmt.Errorf("generate sign key: %w", err)
	}

	wgPub := privKey.PublicKey()
	payload, err := json.Marshal(m.PeerInitArgs{
		WGPubKey:   wgPub[:],
		SignPubKey: signPub[:],
	})
	if err != nil {
		return none, fmt.Errorf("json error: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, hubURL+"/peer/init/", bytes.NewReader(payload))
	if err != nil {
		return none, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.SetBasicAuth("", apiKey)

	client := &http.Client{Timeout: time.Minute}
	httpResp, err := client.Do(req)
	if err != nil {
		return none, fmt.Errorf("hub init: %w", err)
	}
	defer httpResp.Body.Close()

	if code := httpResp.StatusCode; code != http.StatusOK {
		return none, fmt.Errorf("hub init: HTTP %d", code)
	}

	var reply m.PeerInitResp
	if err := json.NewDecoder(httpResp.Body).Decode(&reply); err != nil {
		return none, fmt.Errorf("hub init decode: %w", err)
	}
	if len(reply.Network) != 4 {
		return none, fmt.Errorf("hub init: invalid network %v", reply.Network)
	}

	// One 4-byte array serves both addresses: first as the network address,
	// then with the host octet swapped in for this peer's own address.
	quad := [4]byte(reply.Network)
	network := netip.AddrFrom4(quad)
	quad[3] = reply.PeerIP
	ownAddr := netip.AddrFrom4(quad)

	// Locate our own entry in the hub's peer list by the assigned PeerIP.
	var own *m.Peer
	for i := range reply.NetworkState.Peers {
		if p := &reply.NetworkState.Peers[i]; p.PeerIP == reply.PeerIP {
			own = p
			break
		}
	}
	if own == nil {
		return none, fmt.Errorf("hub init: no peer for own IP: %d", reply.PeerIP)
	}

	reachable := own.IsPublic()
	return LocalState{
		PrivKey:     privKey,
		SignKey:     *signPriv,
		VPNIP:       ownAddr,
		VPNNet:      netip.PrefixFrom(network, 24),
		WGPort:      own.Port,
		IsRelay:     reachable && own.Relay, // a relay must also be public
		IsPublic:    reachable,
		LocalDomain: reply.LocalDomain,
	}, nil
}
// MarshalJSON encodes s in its on-disk form (localStateJSON): both keys are
// written as standard-base64 strings and every other field is copied through
// unchanged.
//
// The receiver is a value so that marshaling a LocalState value (not just a
// pointer) uses this method.
func (s LocalState) MarshalJSON() ([]byte, error) {
	enc := base64.StdEncoding

	disk := localStateJSON{
		VPNIP:       s.VPNIP,
		VPNNet:      s.VPNNet,
		WGPort:      s.WGPort,
		IsRelay:     s.IsRelay,
		IsPublic:    s.IsPublic,
		LocalDomain: s.LocalDomain,
	}
	disk.PrivKey = enc.EncodeToString(s.PrivKey[:])
	disk.SignKey = enc.EncodeToString(s.SignKey[:])

	return json.Marshal(disk)
}
// UnmarshalJSON decodes the on-disk form (localStateJSON) into s.
//
// The WireGuard key must be valid standard base64 and be accepted by
// wgtypes.NewKey; the signing key must be valid standard base64 and decode to
// exactly 64 bytes. Checks run in that order, and s is left untouched if any
// of them fails. On success every field of s is overwritten.
func (s *LocalState) UnmarshalJSON(data []byte) error {
	var disk localStateJSON
	if err := json.Unmarshal(data, &disk); err != nil {
		return err
	}

	rawPriv, err := base64.StdEncoding.DecodeString(disk.PrivKey)
	if err != nil {
		return fmt.Errorf("decode key: %w", err)
	}
	privKey, err := wgtypes.NewKey(rawPriv)
	if err != nil {
		return fmt.Errorf("invalid key: %w", err)
	}

	rawSign, err := base64.StdEncoding.DecodeString(disk.SignKey)
	if err != nil {
		return fmt.Errorf("decode sign key: %w", err)
	}
	if n := len(rawSign); n != 64 {
		return fmt.Errorf("invalid sign key length: %d", n)
	}
	// Length was verified above, so this copies all 64 bytes.
	var signKey [64]byte
	copy(signKey[:], rawSign)

	*s = LocalState{
		PrivKey:     privKey,
		SignKey:     signKey,
		VPNIP:       disk.VPNIP,
		VPNNet:      disk.VPNNet,
		WGPort:      disk.WGPort,
		IsRelay:     disk.IsRelay,
		IsPublic:    disk.IsPublic,
		LocalDomain: disk.LocalDomain,
	}
	return nil
}