A bit of refactoring.
parent
906482e36c
commit
357d2469d8
180
goautossh.go
180
goautossh.go
|
@ -2,19 +2,16 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.crumpington.com/public/golib/clog"
|
||||
)
|
||||
|
||||
var log clog.Logger
|
||||
|
||||
type Config struct {
|
||||
type SSHWatcher struct {
|
||||
// userCmd is appended to the ssh command line. We additionally add
|
||||
// commands to make a forwarding loop to monitor the connection.
|
||||
userCmd string
|
||||
|
@ -37,55 +34,91 @@ type Config struct {
|
|||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
type StateFunc func(*Config) StateFunc
|
||||
func NewSSHWatcher() *SSHWatcher {
|
||||
return &SSHWatcher{
|
||||
userCmd: strings.Join(os.Args[1:], " "),
|
||||
connectWait: 8 * time.Second,
|
||||
pingInterval: 8 * time.Second,
|
||||
pingTimeout: 32 * time.Second,
|
||||
retryWait: 32 * time.Second,
|
||||
pingChan: make(chan byte),
|
||||
}
|
||||
}
|
||||
|
||||
func runSshCommand(conf *Config) StateFunc {
|
||||
conf.pingClientPort = 32768 + rand.Intn(28233)
|
||||
conf.pingServerPort = 32768 + rand.Intn(28233)
|
||||
type StateFunc func() StateFunc
|
||||
|
||||
conf.cmdStr = "ssh " +
|
||||
func (w *SSHWatcher) Run() {
|
||||
fn := w.runSSHCommand()
|
||||
for {
|
||||
fn = fn()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *SSHWatcher) runSSHCommand() StateFunc {
|
||||
w.pingClientPort = 32768 + rand.Intn(28233)
|
||||
w.pingServerPort = 32768 + rand.Intn(28233)
|
||||
|
||||
w.cmdStr = "ssh " +
|
||||
"-o ControlPersist=no -o ControlMaster=no -o GatewayPorts=yes " +
|
||||
"-N -L " +
|
||||
fmt.Sprintf("%v:localhost:%v -R %v:localhost:%v ",
|
||||
conf.pingClientPort,
|
||||
conf.pingClientPort,
|
||||
conf.pingClientPort,
|
||||
conf.pingServerPort) +
|
||||
conf.userCmd
|
||||
w.pingClientPort,
|
||||
w.pingClientPort,
|
||||
w.pingClientPort,
|
||||
w.pingServerPort) +
|
||||
w.userCmd
|
||||
|
||||
log.Msg("Running command: %v", conf.cmdStr)
|
||||
conf.cmd = exec.Command("bash", "-c", conf.cmdStr)
|
||||
log.Printf("Running command: %v", w.cmdStr)
|
||||
w.cmd = exec.Command("bash", "-c", w.cmdStr)
|
||||
|
||||
go func() {
|
||||
output, err := conf.cmd.CombinedOutput()
|
||||
log.Msg("SSH command output: %v", string(output))
|
||||
output, err := w.cmd.CombinedOutput()
|
||||
log.Printf("SSH command output: %v", string(output))
|
||||
if err != nil {
|
||||
log.Err(err, "When executing SSH command")
|
||||
log.Printf("Failed to execute command: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return startPingServer
|
||||
return w.startPingServer
|
||||
}
|
||||
|
||||
func sleepRetry(conf *Config) StateFunc {
|
||||
log.Msg("Sleeping before retrying...")
|
||||
conf.cmd.Process.Kill()
|
||||
if conf.pingConn != nil {
|
||||
conf.pingConn.Close()
|
||||
conf.pingConn = nil
|
||||
func (w *SSHWatcher) sleepRetry() StateFunc {
|
||||
log.Printf("Sleeping before retrying...")
|
||||
w.cmd.Process.Kill()
|
||||
if w.pingConn != nil {
|
||||
w.pingConn.Close()
|
||||
w.pingConn = nil
|
||||
}
|
||||
if conf.pingListener != nil {
|
||||
conf.pingListener.Close()
|
||||
conf.pingListener = nil
|
||||
if w.pingListener != nil {
|
||||
w.pingListener.Close()
|
||||
w.pingListener = nil
|
||||
}
|
||||
time.Sleep(conf.retryWait)
|
||||
return runSshCommand
|
||||
time.Sleep(w.retryWait)
|
||||
return w.runSSHCommand
|
||||
}
|
||||
|
||||
func runPingServer(l net.Listener, pingChan chan byte) {
|
||||
conn, err := l.Accept()
|
||||
func (w *SSHWatcher) startPingServer() StateFunc {
|
||||
addr := fmt.Sprintf("localhost:%v", w.pingServerPort)
|
||||
log.Printf("Starting ping server on: %v", addr)
|
||||
|
||||
var err error
|
||||
w.pingListener, err = net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
log.Err(err, "When accepting ping connection")
|
||||
log.Printf("Failed to create server listener: %v", err)
|
||||
return w.sleepRetry
|
||||
}
|
||||
|
||||
go w.runPingServer()
|
||||
|
||||
time.Sleep(w.connectWait)
|
||||
|
||||
return w.startPingClient
|
||||
}
|
||||
|
||||
func (w *SSHWatcher) runPingServer() {
|
||||
conn, err := w.pingListener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("Failed to accept ping connection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -94,98 +127,69 @@ func runPingServer(l net.Listener, pingChan chan byte) {
|
|||
for {
|
||||
_, err = conn.Read(buf)
|
||||
if err != nil {
|
||||
log.Err(err, "When reading from ping connection")
|
||||
log.Printf("Failed to read from ping connection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case pingChan <- buf[0]:
|
||||
case w.pingChan <- buf[0]:
|
||||
|
||||
default:
|
||||
log.Msg("Ping channel full. Stopping ping server.")
|
||||
log.Printf("Ping channel full. Stopping ping server.")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func startPingServer(conf *Config) StateFunc {
|
||||
addr := fmt.Sprintf("localhost:%v", conf.pingServerPort)
|
||||
log.Msg("Starting ping server on: %v", addr)
|
||||
func (w *SSHWatcher) startPingClient() StateFunc {
|
||||
addr := fmt.Sprintf("localhost:%v", w.pingClientPort)
|
||||
log.Printf("Starting ping client on: %v", addr)
|
||||
|
||||
var err error
|
||||
conf.pingListener, err = net.Listen("tcp", addr)
|
||||
w.pingConn, err = net.DialTimeout("tcp", addr, w.pingInterval)
|
||||
if err != nil {
|
||||
log.Err(err, "When creating server listener")
|
||||
return sleepRetry
|
||||
log.Printf("Failed to dial ping client port: %v", err)
|
||||
return w.sleepRetry
|
||||
}
|
||||
|
||||
go runPingServer(conf.pingListener, conf.pingChan)
|
||||
go w.runPingClient()
|
||||
|
||||
time.Sleep(conf.connectWait)
|
||||
|
||||
return startPingClient
|
||||
return w.pingLoop
|
||||
}
|
||||
|
||||
func runPingClient(conn net.Conn, pingTimeout, pingInterval time.Duration) {
|
||||
func (w *SSHWatcher) runPingClient() {
|
||||
// Send pings.
|
||||
for {
|
||||
// Set timeout.
|
||||
err := conn.SetWriteDeadline(time.Now().Add(pingTimeout))
|
||||
err := w.pingConn.SetWriteDeadline(time.Now().Add(w.pingTimeout))
|
||||
if err != nil {
|
||||
log.Err(err, "When setting ping client write deadline")
|
||||
log.Printf("Failed to set ping client write deadline: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Write ping data.
|
||||
if _, err = conn.Write([]byte("1")); err != nil {
|
||||
log.Err(err, "When writing ping data")
|
||||
if _, err = w.pingConn.Write([]byte("1")); err != nil {
|
||||
log.Printf("Failed to write ping data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(pingInterval)
|
||||
time.Sleep(w.pingInterval)
|
||||
}
|
||||
}
|
||||
|
||||
func startPingClient(conf *Config) StateFunc {
|
||||
addr := fmt.Sprintf("localhost:%v", conf.pingClientPort)
|
||||
log.Msg("Starting ping client on: %v", addr)
|
||||
|
||||
var err error
|
||||
conf.pingConn, err = net.DialTimeout("tcp", addr, conf.pingInterval)
|
||||
if err != nil {
|
||||
log.Err(err, "When dialing ping client port")
|
||||
return sleepRetry
|
||||
}
|
||||
|
||||
go runPingClient(conf.pingConn, conf.pingTimeout, conf.pingInterval)
|
||||
|
||||
return pingLoop
|
||||
}
|
||||
|
||||
func pingLoop(conf *Config) StateFunc {
|
||||
func (w *SSHWatcher) pingLoop() StateFunc {
|
||||
for {
|
||||
select {
|
||||
case <-conf.pingChan:
|
||||
log.Msg("Ping")
|
||||
case <-time.After(conf.pingTimeout):
|
||||
log.Msg("Timed out waiting for ping.")
|
||||
return sleepRetry
|
||||
case <-w.pingChan:
|
||||
log.Printf("Ping")
|
||||
case <-time.After(w.pingTimeout):
|
||||
log.Printf("Timed out waiting for ping.")
|
||||
return w.sleepRetry
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
conf := Config{}
|
||||
conf.userCmd = strings.Join(os.Args[1:], " ")
|
||||
log = clog.New("AutoSSH: " + conf.userCmd)
|
||||
conf.connectWait = 8 * time.Second
|
||||
conf.pingInterval = 8 * time.Second
|
||||
conf.pingTimeout = 32 * time.Second
|
||||
conf.retryWait = 32 * time.Second
|
||||
conf.pingChan = make(chan byte)
|
||||
fn := runSshCommand(&conf)
|
||||
for {
|
||||
fn = fn(&conf)
|
||||
}
|
||||
NewSSHWatcher().Run()
|
||||
}
|
||||
|
|
Reference in New Issue