diff --git a/README.md b/README.md index 47893d6..eee142b 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,15 @@ # goautossh -Autossh in Go. \ No newline at end of file +Autossh in Go. + +Usage: + +``` +goautossh +``` + +For example: + +``` +goautossh ssh -N -L123:remote:123 x@y.com +``` diff --git a/goautossh.go b/goautossh.go index 8b60331..c071984 100644 --- a/goautossh.go +++ b/goautossh.go @@ -22,13 +22,13 @@ type SSHWatcher struct { pingInterval time.Duration // Time between pings. pingTimeout time.Duration // Fail timeout for ping loop. - retryWait time.Duration // + retryWait time.Duration // Time to wait between failure and reconnect. pingClientPort int pingServerPort int - pingListener net.Listener // Server - pingConn net.Conn // Client + pingListener net.Listener // Server. This is thread-safe. + pingConn net.Conn // Client. This is thread-safe. pingChan chan byte cmdStr string cmd *exec.Cmd @@ -39,9 +39,9 @@ func NewSSHWatcher() *SSHWatcher { userCmd: strings.Join(os.Args[1:], " "), connectWait: 8 * time.Second, pingInterval: 8 * time.Second, - pingTimeout: 32 * time.Second, - retryWait: 32 * time.Second, - pingChan: make(chan byte), + pingTimeout: 24 * time.Second, + retryWait: 16 * time.Second, + pingChan: make(chan byte, 1), } } @@ -108,15 +108,17 @@ func (w *SSHWatcher) startPingServer() StateFunc { return w.sleepRetry } - go w.runPingServer() + go w.runPingServer(w.pingListener) time.Sleep(w.connectWait) return w.startPingClient } -func (w *SSHWatcher) runPingServer() { - conn, err := w.pingListener.Accept() +func (w *SSHWatcher) runPingServer(listener net.Listener) { + defer listener.Close() + + conn, err := listener.Accept() if err != nil { log.Printf("Failed to accept ping connection: %v", err) return @@ -152,23 +154,25 @@ func (w *SSHWatcher) startPingClient() StateFunc { return w.sleepRetry } - go w.runPingClient() + go w.runPingClient(w.pingConn) return w.pingLoop } -func (w *SSHWatcher) runPingClient() { +func (w *SSHWatcher) runPingClient(conn net.Conn) { + defer conn.Close() + // Send pings. for { // Set timeout. - err := w.pingConn.SetWriteDeadline(time.Now().Add(w.pingTimeout)) + err := conn.SetWriteDeadline(time.Now().Add(w.pingTimeout)) if err != nil { log.Printf("Failed to set ping client write deadline: %v", err) return } // Write ping data. - if _, err = w.pingConn.Write([]byte("1")); err != nil { + if _, err = conn.Write([]byte("1")); err != nil { log.Printf("Failed to write ping data: %v", err) return }