A bit of refactoring.

master
J. David Lee 2017-11-25 17:13:13 +01:00
parent 906482e36c
commit 357d2469d8
1 changed files with 92 additions and 88 deletions

View File

@ -2,19 +2,16 @@ package main
import ( import (
"fmt" "fmt"
"log"
"math/rand" "math/rand"
"net" "net"
"os" "os"
"os/exec" "os/exec"
"strings" "strings"
"time" "time"
"git.crumpington.com/public/golib/clog"
) )
var log clog.Logger type SSHWatcher struct {
type Config struct {
// userCmd is appended to the ssh command line. We additionally add // userCmd is appended to the ssh command line. We additionally add
// commands to make a forwarding loop to monitor the connection. // commands to make a forwarding loop to monitor the connection.
userCmd string userCmd string
@ -37,55 +34,91 @@ type Config struct {
cmd *exec.Cmd cmd *exec.Cmd
} }
type StateFunc func(*Config) StateFunc func NewSSHWatcher() *SSHWatcher {
return &SSHWatcher{
userCmd: strings.Join(os.Args[1:], " "),
connectWait: 8 * time.Second,
pingInterval: 8 * time.Second,
pingTimeout: 32 * time.Second,
retryWait: 32 * time.Second,
pingChan: make(chan byte),
}
}
func runSshCommand(conf *Config) StateFunc { type StateFunc func() StateFunc
conf.pingClientPort = 32768 + rand.Intn(28233)
conf.pingServerPort = 32768 + rand.Intn(28233)
conf.cmdStr = "ssh " + func (w *SSHWatcher) Run() {
fn := w.runSSHCommand()
for {
fn = fn()
}
}
func (w *SSHWatcher) runSSHCommand() StateFunc {
w.pingClientPort = 32768 + rand.Intn(28233)
w.pingServerPort = 32768 + rand.Intn(28233)
w.cmdStr = "ssh " +
"-o ControlPersist=no -o ControlMaster=no -o GatewayPorts=yes " + "-o ControlPersist=no -o ControlMaster=no -o GatewayPorts=yes " +
"-N -L " + "-N -L " +
fmt.Sprintf("%v:localhost:%v -R %v:localhost:%v ", fmt.Sprintf("%v:localhost:%v -R %v:localhost:%v ",
conf.pingClientPort, w.pingClientPort,
conf.pingClientPort, w.pingClientPort,
conf.pingClientPort, w.pingClientPort,
conf.pingServerPort) + w.pingServerPort) +
conf.userCmd w.userCmd
log.Msg("Running command: %v", conf.cmdStr) log.Printf("Running command: %v", w.cmdStr)
conf.cmd = exec.Command("bash", "-c", conf.cmdStr) w.cmd = exec.Command("bash", "-c", w.cmdStr)
go func() { go func() {
output, err := conf.cmd.CombinedOutput() output, err := w.cmd.CombinedOutput()
log.Msg("SSH command output: %v", string(output)) log.Printf("SSH command output: %v", string(output))
if err != nil { if err != nil {
log.Err(err, "When executing SSH command") log.Printf("Failed to execute command: %v", err)
} }
}() }()
return startPingServer return w.startPingServer
} }
func sleepRetry(conf *Config) StateFunc { func (w *SSHWatcher) sleepRetry() StateFunc {
log.Msg("Sleeping before retrying...") log.Printf("Sleeping before retrying...")
conf.cmd.Process.Kill() w.cmd.Process.Kill()
if conf.pingConn != nil { if w.pingConn != nil {
conf.pingConn.Close() w.pingConn.Close()
conf.pingConn = nil w.pingConn = nil
} }
if conf.pingListener != nil { if w.pingListener != nil {
conf.pingListener.Close() w.pingListener.Close()
conf.pingListener = nil w.pingListener = nil
} }
time.Sleep(conf.retryWait) time.Sleep(w.retryWait)
return runSshCommand return w.runSSHCommand
} }
func runPingServer(l net.Listener, pingChan chan byte) { func (w *SSHWatcher) startPingServer() StateFunc {
conn, err := l.Accept() addr := fmt.Sprintf("localhost:%v", w.pingServerPort)
log.Printf("Starting ping server on: %v", addr)
var err error
w.pingListener, err = net.Listen("tcp", addr)
if err != nil { if err != nil {
log.Err(err, "When accepting ping connection") log.Printf("Failed to create server listener: %v", err)
return w.sleepRetry
}
go w.runPingServer()
time.Sleep(w.connectWait)
return w.startPingClient
}
func (w *SSHWatcher) runPingServer() {
conn, err := w.pingListener.Accept()
if err != nil {
log.Printf("Failed to accept ping connection: %v", err)
return return
} }
@ -94,98 +127,69 @@ func runPingServer(l net.Listener, pingChan chan byte) {
for { for {
_, err = conn.Read(buf) _, err = conn.Read(buf)
if err != nil { if err != nil {
log.Err(err, "When reading from ping connection") log.Printf("Failed to read from ping connection: %v", err)
return return
} }
select { select {
case pingChan <- buf[0]: case w.pingChan <- buf[0]:
default: default:
log.Msg("Ping channel full. Stopping ping server.") log.Printf("Ping channel full. Stopping ping server.")
return return
} }
} }
} }
func startPingServer(conf *Config) StateFunc { func (w *SSHWatcher) startPingClient() StateFunc {
addr := fmt.Sprintf("localhost:%v", conf.pingServerPort) addr := fmt.Sprintf("localhost:%v", w.pingClientPort)
log.Msg("Starting ping server on: %v", addr) log.Printf("Starting ping client on: %v", addr)
var err error var err error
conf.pingListener, err = net.Listen("tcp", addr) w.pingConn, err = net.DialTimeout("tcp", addr, w.pingInterval)
if err != nil { if err != nil {
log.Err(err, "When creating server listener") log.Printf("Failed to dial ping client port: %v", err)
return sleepRetry return w.sleepRetry
} }
go runPingServer(conf.pingListener, conf.pingChan) go w.runPingClient()
time.Sleep(conf.connectWait) return w.pingLoop
return startPingClient
} }
func runPingClient(conn net.Conn, pingTimeout, pingInterval time.Duration) { func (w *SSHWatcher) runPingClient() {
// Send pings. // Send pings.
for { for {
// Set timeout. // Set timeout.
err := conn.SetWriteDeadline(time.Now().Add(pingTimeout)) err := w.pingConn.SetWriteDeadline(time.Now().Add(w.pingTimeout))
if err != nil { if err != nil {
log.Err(err, "When setting ping client write deadline") log.Printf("Failed to set ping client write deadline: %v", err)
return return
} }
// Write ping data. // Write ping data.
if _, err = conn.Write([]byte("1")); err != nil { if _, err = w.pingConn.Write([]byte("1")); err != nil {
log.Err(err, "When writing ping data") log.Printf("Failed to write ping data: %v", err)
return return
} }
time.Sleep(pingInterval) time.Sleep(w.pingInterval)
} }
} }
func startPingClient(conf *Config) StateFunc { func (w *SSHWatcher) pingLoop() StateFunc {
addr := fmt.Sprintf("localhost:%v", conf.pingClientPort)
log.Msg("Starting ping client on: %v", addr)
var err error
conf.pingConn, err = net.DialTimeout("tcp", addr, conf.pingInterval)
if err != nil {
log.Err(err, "When dialing ping client port")
return sleepRetry
}
go runPingClient(conf.pingConn, conf.pingTimeout, conf.pingInterval)
return pingLoop
}
func pingLoop(conf *Config) StateFunc {
for { for {
select { select {
case <-conf.pingChan: case <-w.pingChan:
log.Msg("Ping") log.Printf("Ping")
case <-time.After(conf.pingTimeout): case <-time.After(w.pingTimeout):
log.Msg("Timed out waiting for ping.") log.Printf("Timed out waiting for ping.")
return sleepRetry return w.sleepRetry
} }
} }
} }
func main() { func main() {
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
conf := Config{} NewSSHWatcher().Run()
conf.userCmd = strings.Join(os.Args[1:], " ")
log = clog.New("AutoSSH: " + conf.userCmd)
conf.connectWait = 8 * time.Second
conf.pingInterval = 8 * time.Second
conf.pingTimeout = 32 * time.Second
conf.retryWait = 32 * time.Second
conf.pingChan = make(chan byte)
fn := runSshCommand(&conf)
for {
fn = fn(&conf)
}
} }