2017-03-10 21:19:03 +00:00
|
|
|
package main
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
2017-11-25 16:13:13 +00:00
|
|
|
"log"
|
2017-03-10 21:19:03 +00:00
|
|
|
"math/rand"
|
|
|
|
"net"
|
|
|
|
"os"
|
|
|
|
"os/exec"
|
|
|
|
"strings"
|
|
|
|
"time"
|
|
|
|
)
|
|
|
|
|
2017-11-25 16:13:13 +00:00
|
|
|
// SSHWatcher supervises an ssh process. Alongside the user's ssh arguments it
// sets up a forwarding loop (a local -L forward plus a -R forward back to this
// host), sends pings around that loop, and restarts ssh when pings stop
// arriving.
type SSHWatcher struct {
	// userCmd is appended to the ssh command line; the watcher puts its own
	// forwarding options in front of it so it can monitor the connection.
	userCmd string

	// Timing knobs.
	connectWait  time.Duration // Delay after launching ssh before monitoring begins.
	pingInterval time.Duration // Pause between successive pings.
	pingTimeout  time.Duration // How long to wait for a ping before declaring failure.
	retryWait    time.Duration // Pause after tearing down a failed attempt before retrying.

	// Ports of the ping forwarding loop; new values are picked for every attempt.
	pingClientPort int
	pingServerPort int

	pingListener net.Listener // Server end of the ping loop.
	pingConn     net.Conn     // Client end of the ping loop.
	pingChan     chan byte    // Carries received pings to the monitoring state.

	cmdStr string    // Shell command line of the current ssh attempt.
	cmd    *exec.Cmd // Process running cmdStr.
}
|
|
|
|
|
2017-11-25 16:13:13 +00:00
|
|
|
func NewSSHWatcher() *SSHWatcher {
|
|
|
|
return &SSHWatcher{
|
|
|
|
userCmd: strings.Join(os.Args[1:], " "),
|
|
|
|
connectWait: 8 * time.Second,
|
|
|
|
pingInterval: 8 * time.Second,
|
|
|
|
pingTimeout: 32 * time.Second,
|
|
|
|
retryWait: 32 * time.Second,
|
|
|
|
pingChan: make(chan byte),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// StateFunc is one state of the SSHWatcher state machine: calling it performs
// that state's work and returns the state to run next (driven by Run).
type StateFunc func() StateFunc
|
2017-03-10 21:19:03 +00:00
|
|
|
|
2017-11-25 16:13:13 +00:00
|
|
|
func (w *SSHWatcher) Run() {
|
|
|
|
fn := w.runSSHCommand()
|
|
|
|
for {
|
|
|
|
fn = fn()
|
|
|
|
}
|
|
|
|
}
|
2017-03-10 21:19:03 +00:00
|
|
|
|
2017-11-25 16:13:13 +00:00
|
|
|
func (w *SSHWatcher) runSSHCommand() StateFunc {
|
|
|
|
w.pingClientPort = 32768 + rand.Intn(28233)
|
|
|
|
w.pingServerPort = 32768 + rand.Intn(28233)
|
|
|
|
|
|
|
|
w.cmdStr = "ssh " +
|
2017-03-10 21:19:03 +00:00
|
|
|
"-o ControlPersist=no -o ControlMaster=no -o GatewayPorts=yes " +
|
|
|
|
"-N -L " +
|
|
|
|
fmt.Sprintf("%v:localhost:%v -R %v:localhost:%v ",
|
2017-11-25 16:13:13 +00:00
|
|
|
w.pingClientPort,
|
|
|
|
w.pingClientPort,
|
|
|
|
w.pingClientPort,
|
|
|
|
w.pingServerPort) +
|
|
|
|
w.userCmd
|
2017-03-10 21:19:03 +00:00
|
|
|
|
2017-11-25 16:13:13 +00:00
|
|
|
log.Printf("Running command: %v", w.cmdStr)
|
|
|
|
w.cmd = exec.Command("bash", "-c", w.cmdStr)
|
2017-03-10 21:19:03 +00:00
|
|
|
|
|
|
|
go func() {
|
2017-11-25 16:13:13 +00:00
|
|
|
output, err := w.cmd.CombinedOutput()
|
|
|
|
log.Printf("SSH command output: %v", string(output))
|
2017-03-10 21:19:03 +00:00
|
|
|
if err != nil {
|
2017-11-25 16:13:13 +00:00
|
|
|
log.Printf("Failed to execute command: %v", err)
|
2017-03-10 21:19:03 +00:00
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
2017-11-25 16:13:13 +00:00
|
|
|
return w.startPingServer
|
2017-03-10 21:19:03 +00:00
|
|
|
}
|
|
|
|
|
2017-11-25 16:13:13 +00:00
|
|
|
func (w *SSHWatcher) sleepRetry() StateFunc {
|
|
|
|
log.Printf("Sleeping before retrying...")
|
|
|
|
w.cmd.Process.Kill()
|
|
|
|
if w.pingConn != nil {
|
|
|
|
w.pingConn.Close()
|
|
|
|
w.pingConn = nil
|
2017-03-10 21:19:03 +00:00
|
|
|
}
|
2017-11-25 16:13:13 +00:00
|
|
|
if w.pingListener != nil {
|
|
|
|
w.pingListener.Close()
|
|
|
|
w.pingListener = nil
|
2017-03-10 21:19:03 +00:00
|
|
|
}
|
2017-11-25 16:13:13 +00:00
|
|
|
time.Sleep(w.retryWait)
|
|
|
|
return w.runSSHCommand
|
2017-03-10 21:19:03 +00:00
|
|
|
}
|
|
|
|
|
2017-11-25 16:13:13 +00:00
|
|
|
func (w *SSHWatcher) startPingServer() StateFunc {
|
|
|
|
addr := fmt.Sprintf("localhost:%v", w.pingServerPort)
|
|
|
|
log.Printf("Starting ping server on: %v", addr)
|
|
|
|
|
|
|
|
var err error
|
|
|
|
w.pingListener, err = net.Listen("tcp", addr)
|
2017-03-10 21:19:03 +00:00
|
|
|
if err != nil {
|
2017-11-25 16:13:13 +00:00
|
|
|
log.Printf("Failed to create server listener: %v", err)
|
|
|
|
return w.sleepRetry
|
|
|
|
}
|
|
|
|
|
|
|
|
go w.runPingServer()
|
|
|
|
|
|
|
|
time.Sleep(w.connectWait)
|
|
|
|
|
|
|
|
return w.startPingClient
|
|
|
|
}
|
|
|
|
|
|
|
|
func (w *SSHWatcher) runPingServer() {
|
|
|
|
conn, err := w.pingListener.Accept()
|
|
|
|
if err != nil {
|
|
|
|
log.Printf("Failed to accept ping connection: %v", err)
|
2017-03-10 21:19:03 +00:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
buf := make([]byte, 1)
|
|
|
|
|
|
|
|
for {
|
|
|
|
_, err = conn.Read(buf)
|
|
|
|
if err != nil {
|
2017-11-25 16:13:13 +00:00
|
|
|
log.Printf("Failed to read from ping connection: %v", err)
|
2017-03-10 21:19:03 +00:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
select {
|
2017-11-25 16:13:13 +00:00
|
|
|
case w.pingChan <- buf[0]:
|
2017-03-10 21:19:03 +00:00
|
|
|
|
|
|
|
default:
|
2017-11-25 16:13:13 +00:00
|
|
|
log.Printf("Ping channel full. Stopping ping server.")
|
2017-03-10 21:19:03 +00:00
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-11-25 16:13:13 +00:00
|
|
|
func (w *SSHWatcher) startPingClient() StateFunc {
|
|
|
|
addr := fmt.Sprintf("localhost:%v", w.pingClientPort)
|
|
|
|
log.Printf("Starting ping client on: %v", addr)
|
2017-03-10 21:19:03 +00:00
|
|
|
|
|
|
|
var err error
|
2017-11-25 16:13:13 +00:00
|
|
|
w.pingConn, err = net.DialTimeout("tcp", addr, w.pingInterval)
|
2017-03-10 21:19:03 +00:00
|
|
|
if err != nil {
|
2017-11-25 16:13:13 +00:00
|
|
|
log.Printf("Failed to dial ping client port: %v", err)
|
|
|
|
return w.sleepRetry
|
2017-03-10 21:19:03 +00:00
|
|
|
}
|
|
|
|
|
2017-11-25 16:13:13 +00:00
|
|
|
go w.runPingClient()
|
2017-03-10 21:19:03 +00:00
|
|
|
|
2017-11-25 16:13:13 +00:00
|
|
|
return w.pingLoop
|
2017-03-10 21:19:03 +00:00
|
|
|
}
|
|
|
|
|
2017-11-25 16:13:13 +00:00
|
|
|
func (w *SSHWatcher) runPingClient() {
|
2017-03-10 21:19:03 +00:00
|
|
|
// Send pings.
|
|
|
|
for {
|
|
|
|
// Set timeout.
|
2017-11-25 16:13:13 +00:00
|
|
|
err := w.pingConn.SetWriteDeadline(time.Now().Add(w.pingTimeout))
|
2017-03-10 21:19:03 +00:00
|
|
|
if err != nil {
|
2017-11-25 16:13:13 +00:00
|
|
|
log.Printf("Failed to set ping client write deadline: %v", err)
|
2017-03-10 21:19:03 +00:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// Write ping data.
|
2017-11-25 16:13:13 +00:00
|
|
|
if _, err = w.pingConn.Write([]byte("1")); err != nil {
|
|
|
|
log.Printf("Failed to write ping data: %v", err)
|
2017-03-10 21:19:03 +00:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2017-11-25 16:13:13 +00:00
|
|
|
time.Sleep(w.pingInterval)
|
2017-03-10 21:19:03 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-11-25 16:13:13 +00:00
|
|
|
func (w *SSHWatcher) pingLoop() StateFunc {
|
2017-03-10 21:19:03 +00:00
|
|
|
for {
|
|
|
|
select {
|
2017-11-25 16:13:13 +00:00
|
|
|
case <-w.pingChan:
|
|
|
|
log.Printf("Ping")
|
|
|
|
case <-time.After(w.pingTimeout):
|
|
|
|
log.Printf("Timed out waiting for ping.")
|
|
|
|
return w.sleepRetry
|
2017-03-10 21:19:03 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func main() {
|
|
|
|
rand.Seed(time.Now().UnixNano())
|
2017-11-25 16:13:13 +00:00
|
|
|
NewSSHWatcher().Run()
|
2017-03-10 21:19:03 +00:00
|
|
|
}
|