// This program runs an ssh command built from its command-line arguments and
// restarts it whenever a ping loop routed through ssh port forwards stops
// delivering pings.
package main

import (
	"bytes"
	"fmt"
	"log"
	"math/rand"
	"net"
	"os"
	"os/exec"
	"strings"
	"time"
)

// SSHWatcher supervises a single ssh process. It is driven as a state machine
// (see StateFunc): start ssh, start a local ping server, dial the forwarded
// ping port, then watch for pings until they stop, at which point everything
// is torn down and restarted.
type SSHWatcher struct {
	// userCmd is appended to the ssh command line. We additionally add
	// port-forwarding options that create a loop used to monitor the
	// connection.
	userCmd string

	// connectWait is how long to wait after launching ssh before the ping
	// client is started.
	connectWait time.Duration

	pingInterval time.Duration // Time between pings.
	pingTimeout  time.Duration // Fail timeout for ping loop.
	retryWait    time.Duration // Time to wait between failure and reconnect.

	// Ports are re-rolled on every connection attempt.
	pingClientPort int // Local port forwarded by ssh (-L); the ping client dials it.
	pingServerPort int // Local port the ping server listens on; target of the -R forward.

	pingListener net.Listener // Server. Safe for concurrent use.
	pingConn     net.Conn     // Client. Safe for concurrent use.

	// pingChan carries each byte read by the ping server to pingLoop.
	pingChan chan byte

	cmdStr string    // Full shell command line handed to "bash -c".
	cmd    *exec.Cmd // The most recently launched ssh command.
}

// NewSSHWatcher returns a watcher whose user command is the program's
// command-line arguments joined with spaces, using the default timings.
func NewSSHWatcher() *SSHWatcher {
	return &SSHWatcher{
		userCmd:      strings.Join(os.Args[1:], " "),
		connectWait:  8 * time.Second,
		pingInterval: 8 * time.Second,
		pingTimeout:  24 * time.Second,
		retryWait:    16 * time.Second,
		pingChan:     make(chan byte, 1),
	}
}

// StateFunc is one state of the watcher; it does its work and returns the
// next state to run.
type StateFunc func() StateFunc

// Run executes the state machine forever, starting with the ssh launch.
func (w *SSHWatcher) Run() {
	fn := w.runSSHCommand()
	for {
		fn = fn()
	}
}

// randomPort returns a port in the range [32768, 61000].
func randomPort() int {
	const (
		minPort  = 32768
		portSpan = 28233
	)
	return minPort + rand.Intn(portSpan)
}

// runSSHCommand picks fresh ping ports, builds the ssh command line and
// launches it through bash. The next state is startPingServer, or sleepRetry
// when the command could not be started.
func (w *SSHWatcher) runSSHCommand() StateFunc {
	w.pingClientPort = randomPort()
	w.pingServerPort = randomPort()
	// The ping server and the ssh -L forward both bind a local port, so the
	// two ports must differ.
	for w.pingServerPort == w.pingClientPort {
		w.pingServerPort = randomPort()
	}

	// -L sends the local client port to the same port on the remote side and
	// -R sends that remote port back to the local ping server, so a byte
	// written by the ping client makes a round trip through the ssh session.
	w.cmdStr = "ssh " +
		"-o ControlPersist=no -o ControlMaster=no -o GatewayPorts=yes " +
		"-N -L " +
		fmt.Sprintf("%v:localhost:%v -R %v:localhost:%v ",
			w.pingClientPort, w.pingClientPort,
			w.pingClientPort, w.pingServerPort) +
		w.userCmd
	log.Printf("Running command: %v", w.cmdStr)

	cmd := exec.Command("bash", "-c", w.cmdStr)
	var output bytes.Buffer
	cmd.Stdout = &output
	cmd.Stderr = &output
	w.cmd = cmd

	// Start synchronously so that cmd.Process is populated (or known to be
	// absent) before any later state tries to kill the process.
	if err := cmd.Start(); err != nil {
		log.Printf("Failed to execute command: %v", err)
		return w.sleepRetry
	}

	// Reap the process in the background and report what it printed. The
	// goroutine uses the local cmd because w.cmd is replaced on every retry.
	go func() {
		err := cmd.Wait()
		log.Printf("SSH command output: %v", output.String())
		if err != nil {
			log.Printf("Failed to execute command: %v", err)
		}
	}()
	return w.startPingServer
}

// sleepRetry tears down the current attempt (ssh process, ping client
// connection, ping listener), waits retryWait and then hands over to
// runSSHCommand.
func (w *SSHWatcher) sleepRetry() StateFunc {
	log.Printf("Sleeping before retrying...")

	// The process is absent when the command never started.
	if w.cmd != nil && w.cmd.Process != nil {
		// Best effort: the process may already have exited on its own.
		_ = w.cmd.Process.Kill()
	}
	if w.pingConn != nil {
		w.pingConn.Close()
		w.pingConn = nil
	}
	if w.pingListener != nil {
		w.pingListener.Close()
		w.pingListener = nil
	}

	time.Sleep(w.retryWait)

	// Drop a ping left over from the attempt that just failed; otherwise the
	// next ping server could find the channel full and stop immediately.
	select {
	case <-w.pingChan:
	default:
	}
	return w.runSSHCommand
}

// startPingServer opens the listener that receives pings coming back through
// the ssh tunnel, waits connectWait for ssh to come up, and moves on to
// startPingClient. A listen failure leads to sleepRetry.
func (w *SSHWatcher) startPingServer() StateFunc {
	addr := fmt.Sprintf("localhost:%v", w.pingServerPort)
	log.Printf("Starting ping server on: %v", addr)

	var err error
	w.pingListener, err = net.Listen("tcp", addr)
	if err != nil {
		log.Printf("Failed to create server listener: %v", err)
		return w.sleepRetry
	}
	go w.runPingServer(w.pingListener)

	time.Sleep(w.connectWait)
	return w.startPingClient
}

// runPingServer accepts a single connection on listener and forwards every
// byte read from it to pingChan. It returns when accepting or reading fails
// (including no data for pingTimeout) or when pingChan is full.
func (w *SSHWatcher) runPingServer(listener net.Listener) {
	defer listener.Close()

	conn, err := listener.Accept()
	if err != nil {
		log.Printf("Failed to accept ping connection: %v", err)
		return
	}
	defer conn.Close()

	buf := make([]byte, 1)
	for {
		// Bound each read so this goroutine cannot block forever on a
		// connection whose peer silently disappeared.
		if err := conn.SetReadDeadline(time.Now().Add(w.pingTimeout)); err != nil {
			log.Printf("Failed to set ping server read deadline: %v", err)
			return
		}
		if _, err := conn.Read(buf); err != nil {
			log.Printf("Failed to read from ping connection: %v", err)
			return
		}
		select {
		case w.pingChan <- buf[0]:
		default:
			log.Printf("Ping channel full. Stopping ping server.")
			return
		}
	}
}

// startPingClient dials the locally forwarded ping port and starts sending
// pings. The next state is pingLoop, or sleepRetry when the dial fails.
func (w *SSHWatcher) startPingClient() StateFunc {
	addr := fmt.Sprintf("localhost:%v", w.pingClientPort)
	log.Printf("Starting ping client on: %v", addr)

	var err error
	w.pingConn, err = net.DialTimeout("tcp", addr, w.pingInterval)
	if err != nil {
		log.Printf("Failed to dial ping client port: %v", err)
		return w.sleepRetry
	}
	go w.runPingClient(w.pingConn)
	return w.pingLoop
}

// runPingClient writes one byte to conn every pingInterval until a write (or
// setting its deadline) fails, then closes conn.
func (w *SSHWatcher) runPingClient(conn net.Conn) {
	defer conn.Close()

	for {
		if err := conn.SetWriteDeadline(time.Now().Add(w.pingTimeout)); err != nil {
			log.Printf("Failed to set ping client write deadline: %v", err)
			return
		}
		if _, err := conn.Write([]byte("1")); err != nil {
			log.Printf("Failed to write ping data: %v", err)
			return
		}
		time.Sleep(w.pingInterval)
	}
}

// pingLoop consumes pings from pingChan and moves to sleepRetry as soon as
// none arrives within pingTimeout.
func (w *SSHWatcher) pingLoop() StateFunc {
	for {
		select {
		case <-w.pingChan:
			log.Printf("Ping")
		case <-time.After(w.pingTimeout):
			log.Printf("Timed out waiting for ping.")
			return w.sleepRetry
		}
	}
}

func main() {
	// Explicit seeding keeps port selection varied on toolchains that do not
	// seed math/rand automatically.
	rand.Seed(time.Now().UnixNano())
	NewSSHWatcher().Run()
}