diff --git a/goautossh.go b/goautossh.go index 962f639..8b60331 100644 --- a/goautossh.go +++ b/goautossh.go @@ -2,19 +2,16 @@ package main import ( "fmt" + "log" "math/rand" "net" "os" "os/exec" "strings" "time" - - "git.crumpington.com/public/golib/clog" ) -var log clog.Logger - -type Config struct { +type SSHWatcher struct { // userCmd is appended to the ssh command line. We additionally add // commands to make a forwarding loop to monitor the connection. userCmd string @@ -37,55 +34,91 @@ type Config struct { cmd *exec.Cmd } -type StateFunc func(*Config) StateFunc +func NewSSHWatcher() *SSHWatcher { + return &SSHWatcher{ + userCmd: strings.Join(os.Args[1:], " "), + connectWait: 8 * time.Second, + pingInterval: 8 * time.Second, + pingTimeout: 32 * time.Second, + retryWait: 32 * time.Second, + pingChan: make(chan byte), + } +} -func runSshCommand(conf *Config) StateFunc { - conf.pingClientPort = 32768 + rand.Intn(28233) - conf.pingServerPort = 32768 + rand.Intn(28233) +type StateFunc func() StateFunc - conf.cmdStr = "ssh " + +func (w *SSHWatcher) Run() { + fn := w.runSSHCommand() + for { + fn = fn() + } +} + +func (w *SSHWatcher) runSSHCommand() StateFunc { + w.pingClientPort = 32768 + rand.Intn(28233) + w.pingServerPort = 32768 + rand.Intn(28233) + + w.cmdStr = "ssh " + "-o ControlPersist=no -o ControlMaster=no -o GatewayPorts=yes " + "-N -L " + fmt.Sprintf("%v:localhost:%v -R %v:localhost:%v ", - conf.pingClientPort, - conf.pingClientPort, - conf.pingClientPort, - conf.pingServerPort) + - conf.userCmd + w.pingClientPort, + w.pingClientPort, + w.pingClientPort, + w.pingServerPort) + + w.userCmd - log.Msg("Running command: %v", conf.cmdStr) - conf.cmd = exec.Command("bash", "-c", conf.cmdStr) + log.Printf("Running command: %v", w.cmdStr) + w.cmd = exec.Command("bash", "-c", w.cmdStr) go func() { - output, err := conf.cmd.CombinedOutput() - log.Msg("SSH command output: %v", string(output)) + output, err := w.cmd.CombinedOutput() + log.Printf("SSH command output: %v", string(output)) if err != nil { - log.Err(err, "When executing SSH command") + log.Printf("Failed to execute command: %v", err) } }() - return startPingServer + return w.startPingServer } -func sleepRetry(conf *Config) StateFunc { - log.Msg("Sleeping before retrying...") - conf.cmd.Process.Kill() - if conf.pingConn != nil { - conf.pingConn.Close() - conf.pingConn = nil +func (w *SSHWatcher) sleepRetry() StateFunc { + log.Printf("Sleeping before retrying...") + w.cmd.Process.Kill() + if w.pingConn != nil { + w.pingConn.Close() + w.pingConn = nil } - if conf.pingListener != nil { - conf.pingListener.Close() - conf.pingListener = nil + if w.pingListener != nil { + w.pingListener.Close() + w.pingListener = nil } - time.Sleep(conf.retryWait) - return runSshCommand + time.Sleep(w.retryWait) + return w.runSSHCommand } -func runPingServer(l net.Listener, pingChan chan byte) { - conn, err := l.Accept() +func (w *SSHWatcher) startPingServer() StateFunc { + addr := fmt.Sprintf("localhost:%v", w.pingServerPort) + log.Printf("Starting ping server on: %v", addr) + + var err error + w.pingListener, err = net.Listen("tcp", addr) if err != nil { - log.Err(err, "When accepting ping connection") + log.Printf("Failed to create server listener: %v", err) + return w.sleepRetry + } + + go w.runPingServer() + + time.Sleep(w.connectWait) + + return w.startPingClient +} + +func (w *SSHWatcher) runPingServer() { + conn, err := w.pingListener.Accept() + if err != nil { + log.Printf("Failed to accept ping connection: %v", err) return } @@ -94,98 +127,69 @@ func runPingServer(l net.Listener, pingChan chan byte) { for { _, err = conn.Read(buf) if err != nil { - log.Err(err, "When reading from ping connection") + log.Printf("Failed to read from ping connection: %v", err) return } select { - case pingChan <- buf[0]: + case w.pingChan <- buf[0]: default: - log.Msg("Ping channel full. Stopping ping server.") + log.Printf("Ping channel full. Stopping ping server.") return } } } -func startPingServer(conf *Config) StateFunc { - addr := fmt.Sprintf("localhost:%v", conf.pingServerPort) - log.Msg("Starting ping server on: %v", addr) +func (w *SSHWatcher) startPingClient() StateFunc { + addr := fmt.Sprintf("localhost:%v", w.pingClientPort) + log.Printf("Starting ping client on: %v", addr) var err error - conf.pingListener, err = net.Listen("tcp", addr) + w.pingConn, err = net.DialTimeout("tcp", addr, w.pingInterval) if err != nil { - log.Err(err, "When creating server listener") - return sleepRetry + log.Printf("Failed to dial ping client port: %v", err) + return w.sleepRetry } - go runPingServer(conf.pingListener, conf.pingChan) + go w.runPingClient() - time.Sleep(conf.connectWait) - - return startPingClient + return w.pingLoop } -func runPingClient(conn net.Conn, pingTimeout, pingInterval time.Duration) { +func (w *SSHWatcher) runPingClient() { // Send pings. for { // Set timeout. - err := conn.SetWriteDeadline(time.Now().Add(pingTimeout)) + err := w.pingConn.SetWriteDeadline(time.Now().Add(w.pingTimeout)) if err != nil { - log.Err(err, "When setting ping client write deadline") + log.Printf("Failed to set ping client write deadline: %v", err) return } // Write ping data. - if _, err = conn.Write([]byte("1")); err != nil { - log.Err(err, "When writing ping data") + if _, err = w.pingConn.Write([]byte("1")); err != nil { + log.Printf("Failed to write ping data: %v", err) return } - time.Sleep(pingInterval) + time.Sleep(w.pingInterval) } } -func startPingClient(conf *Config) StateFunc { - addr := fmt.Sprintf("localhost:%v", conf.pingClientPort) - log.Msg("Starting ping client on: %v", addr) - - var err error - conf.pingConn, err = net.DialTimeout("tcp", addr, conf.pingInterval) - if err != nil { - log.Err(err, "When dialing ping client port") - return sleepRetry - } - - go runPingClient(conf.pingConn, conf.pingTimeout, conf.pingInterval) - - return pingLoop -} - -func pingLoop(conf *Config) StateFunc { +func (w *SSHWatcher) pingLoop() StateFunc { for { select { - case <-conf.pingChan: - log.Msg("Ping") - case <-time.After(conf.pingTimeout): - log.Msg("Timed out waiting for ping.") - return sleepRetry + case <-w.pingChan: + log.Printf("Ping") + case <-time.After(w.pingTimeout): + log.Printf("Timed out waiting for ping.") + return w.sleepRetry } } } func main() { rand.Seed(time.Now().UnixNano()) - conf := Config{} - conf.userCmd = strings.Join(os.Args[1:], " ") - log = clog.New("AutoSSH: " + conf.userCmd) - conf.connectWait = 8 * time.Second - conf.pingInterval = 8 * time.Second - conf.pingTimeout = 32 * time.Second - conf.retryWait = 32 * time.Second - conf.pingChan = make(chan byte) - fn := runSshCommand(&conf) - for { - fn = fn(&conf) - } + NewSSHWatcher().Run() }