// sf-worker — SSH PTY daemon that executes sf headless runs on behalf of a remote orchestrator.
//
// Purpose: allow the SF orchestrator to dispatch autonomous unit attempts to remote hosts
// (GPU boxes, Windows machines, parallel workers) over SSH without requiring a full SF
// installation on the controlling machine.
//
// Usage:
//
//	sf-worker [flags]
//
//	--addr             SSH listen address (default ":2222")
//	--metrics-addr     Prometheus /metrics address, also serves /health (default ":9100")
//	--host-key         Path to SSH host key (default "~/.sf/worker_host_key")
//	--authorized-keys  Path to authorized_keys file (default "~/.sf/worker_authorized_keys")
//	--sf-bin           Path to sf binary (default: resolved from $PATH)
//	--max-sessions     Maximum concurrent sessions (default 16)
package main
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"flag"
|
||
|
|
"fmt"
|
||
|
|
"net/http"
|
||
|
|
"os"
|
||
|
|
"os/signal"
|
||
|
|
"path/filepath"
|
||
|
|
"syscall"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"github.com/charmbracelet/log"
|
||
|
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||
|
|
)
|
||
|
|
|
||
|
|
// defaultPath returns rel located inside the per-user ".sf" directory
// (i.e. filepath.Join(home, ".sf", rel)). If the user's home directory cannot
// be determined, rel is returned as-is, so it is interpreted relative to the
// process's working directory.
func defaultPath(rel string) string {
	if home, err := os.UserHomeDir(); err == nil {
		return filepath.Join(home, ".sf", rel)
	}
	return rel
}
|
||
|
|
|
||
|
|
func main() {
|
||
|
|
addr := flag.String("addr", ":2222", "SSH listen address")
|
||
|
|
metricsAddr := flag.String("metrics-addr", ":9100", "Prometheus metrics address")
|
||
|
|
hostKeyPath := flag.String("host-key", defaultPath("worker_host_key"), "SSH host key path (RSA/Ed25519 PEM)")
|
||
|
|
authorizedKeysPath := flag.String("authorized-keys", defaultPath("worker_authorized_keys"), "authorized_keys path")
|
||
|
|
sfBin := flag.String("sf-bin", "", "Path to sf binary (default: resolved from $PATH)")
|
||
|
|
maxSessions := flag.Int("max-sessions", 16, "Maximum concurrent sessions")
|
||
|
|
flag.Parse()
|
||
|
|
|
||
|
|
logger := log.NewWithOptions(os.Stderr, log.Options{
|
||
|
|
ReportTimestamp: true,
|
||
|
|
TimeFormat: time.RFC3339,
|
||
|
|
Level: log.InfoLevel,
|
||
|
|
})
|
||
|
|
|
||
|
|
srv, err := NewServer(ServerConfig{
|
||
|
|
Addr: *addr,
|
||
|
|
HostKeyPath: *hostKeyPath,
|
||
|
|
AuthorizedKeysPath: *authorizedKeysPath,
|
||
|
|
SFBin: *sfBin,
|
||
|
|
MaxSessions: *maxSessions,
|
||
|
|
Logger: logger,
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
logger.Error("failed to create server", "err", err)
|
||
|
|
os.Exit(1)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Prometheus metrics endpoint.
|
||
|
|
metricsMux := http.NewServeMux()
|
||
|
|
metricsMux.Handle("/metrics", promhttp.Handler())
|
||
|
|
metricsMux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) {
|
||
|
|
fmt.Fprintln(w, "ok")
|
||
|
|
})
|
||
|
|
metricsServer := &http.Server{
|
||
|
|
Addr: *metricsAddr,
|
||
|
|
Handler: metricsMux,
|
||
|
|
}
|
||
|
|
go func() {
|
||
|
|
logger.Info("metrics server listening", "addr", *metricsAddr)
|
||
|
|
if err := metricsServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||
|
|
logger.Error("metrics server error", "err", err)
|
||
|
|
}
|
||
|
|
}()
|
||
|
|
|
||
|
|
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||
|
|
defer cancel()
|
||
|
|
|
||
|
|
logger.Info("sf-worker starting", "addr", *addr)
|
||
|
|
if err := srv.ListenAndServe(ctx); err != nil {
|
||
|
|
logger.Error("server error", "err", err)
|
||
|
|
os.Exit(1)
|
||
|
|
}
|
||
|
|
|
||
|
|
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||
|
|
defer shutdownCancel()
|
||
|
|
_ = metricsServer.Shutdown(shutdownCtx)
|
||
|
|
|
||
|
|
logger.Info("sf-worker stopped")
|
||
|
|
}
|