singularity-forge/sf-worker/main.go

101 lines
3.1 KiB
Go
Raw Normal View History

// sf-worker — SSH PTY daemon that executes sf headless runs on behalf of a remote orchestrator.
//
// Purpose: allow the SF orchestrator to dispatch autonomous unit attempts to remote hosts
// (GPU boxes, Windows machines, parallel workers) over SSH without requiring a full SF
// installation on the controlling machine.
//
// Usage:
//
// sf-worker [flags]
// --addr SSH listen address (default ":2222")
// --metrics-addr HTTP address serving Prometheus /metrics and a /health probe (default ":9100")
// --host-key Path to SSH host key (default "~/.sf/worker_host_key")
// --authorized-keys Path to authorized_keys file (default "~/.sf/worker_authorized_keys")
// --sf-bin Path to sf binary (default: resolved from $PATH)
// --max-sessions Maximum concurrent sessions (default 16)
package main
import (
"context"
"flag"
"fmt"
"net/http"
"os"
"os/signal"
"path/filepath"
"syscall"
"time"
"github.com/charmbracelet/log"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// defaultPath returns the location of rel inside the per-user ".sf" directory
// (for example ~/.sf/worker_host_key). If the home directory cannot be
// determined, rel is returned unchanged, so the path is resolved relative to
// the current working directory.
func defaultPath(rel string) string {
	dir, err := os.UserHomeDir()
	if err == nil {
		return filepath.Join(dir, ".sf", rel)
	}
	// No usable home directory: fall back to the bare relative name.
	return rel
}
// main parses the command-line flags, builds the SSH worker server, and starts
// an HTTP listener that serves Prometheus /metrics and a /health probe. It then
// blocks in srv.ListenAndServe with a context that is cancelled on SIGINT or
// SIGTERM. Once that call returns, the metrics listener is shut down
// gracefully (bounded by a 10s timeout) on both the success and the error path.
// The process exits with status 1 if the server cannot be created or if
// ListenAndServe returns an error.
func main() {
	addr := flag.String("addr", ":2222", "SSH listen address")
	metricsAddr := flag.String("metrics-addr", ":9100", "Prometheus metrics address")
	hostKeyPath := flag.String("host-key", defaultPath("worker_host_key"), "SSH host key path (RSA/Ed25519 PEM)")
	authorizedKeysPath := flag.String("authorized-keys", defaultPath("worker_authorized_keys"), "authorized_keys path")
	sfBin := flag.String("sf-bin", "", "Path to sf binary (default: resolved from $PATH)")
	maxSessions := flag.Int("max-sessions", 16, "Maximum concurrent sessions")
	flag.Parse()

	logger := log.NewWithOptions(os.Stderr, log.Options{
		ReportTimestamp: true,
		TimeFormat:      time.RFC3339,
		Level:           log.InfoLevel,
	})

	srv, err := NewServer(ServerConfig{
		Addr:               *addr,
		HostKeyPath:        *hostKeyPath,
		AuthorizedKeysPath: *authorizedKeysPath,
		SFBin:              *sfBin,
		MaxSessions:        *maxSessions,
		Logger:             logger,
	})
	if err != nil {
		logger.Error("failed to create server", "err", err)
		os.Exit(1)
	}

	// Prometheus metrics endpoint plus a trivial liveness probe.
	metricsMux := http.NewServeMux()
	metricsMux.Handle("/metrics", promhttp.Handler())
	metricsMux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) {
		fmt.Fprintln(w, "ok")
	})
	// Bound how long a slow or idle client may hold a connection. Without
	// ReadHeaderTimeout, a client that trickles request headers can pin a
	// connection indefinitely (Slowloris).
	metricsServer := &http.Server{
		Addr:              *metricsAddr,
		Handler:           metricsMux,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       10 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       60 * time.Second,
	}
	go func() {
		logger.Info("metrics server listening", "addr", *metricsAddr)
		// A metrics listener failure (e.g. port already in use) is only
		// logged; the SSH worker keeps running without metrics.
		if err := metricsServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			logger.Error("metrics server error", "err", err)
		}
	}()

	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	logger.Info("sf-worker starting", "addr", *addr)
	serveErr := srv.ListenAndServe(ctx)
	// Serving is over: unregister the signal handlers so the default signal
	// behavior is restored while the remaining cleanup runs.
	stop()
	if serveErr != nil {
		logger.Error("server error", "err", serveErr)
	}

	// Shut the metrics listener down on both the clean and the error path.
	// Cleanup is explicit rather than deferred because os.Exit below would
	// skip deferred calls.
	shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), 10*time.Second)
	if err := metricsServer.Shutdown(shutdownCtx); err != nil {
		logger.Error("metrics server shutdown error", "err", err)
	}
	cancelShutdown()

	if serveErr != nil {
		os.Exit(1)
	}
	logger.Info("sf-worker stopped")
}