// server.go — SSH server setup and connection acceptance.
//
// Purpose: accept SSH connections from the SF orchestrator, enforce key-based auth,
// and hand each session off to runSession for PTY execution.
package main

import (
	"bufio"
	"context"
	"fmt"
	"net"
	"os"
	"os/exec"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/charmbracelet/log"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"golang.org/x/crypto/ssh"
)

// Prometheus metrics exported by the worker. They are registered through
// promauto at package initialisation.
var (
	// metricActiveSessions is raised by handleConn once an SSH handshake has
	// succeeded and lowered again when that connection ends.
	metricActiveSessions = promauto.NewGauge(prometheus.GaugeOpts{
		Name: "sfworker_active_sessions",
		Help: "Currently active SSH sessions.",
	})
	// metricSessionsTotal counts outcomes under the "outcome" label. handleConn
	// records "auth_failed" when the handshake fails, "error" when a channel
	// cannot be accepted, and otherwise the string returned by runSession.
	metricSessionsTotal = promauto.NewCounterVec(prometheus.CounterOpts{
		Name: "sfworker_sessions_total",
		Help: "Total SSH sessions by outcome.",
	}, []string{"outcome"})
	// metricSessionDuration is observed by handleConn when a connection ends,
	// using the time elapsed since handleConn started. The histogram has 12
	// buckets beginning at 1 second, each twice the width of the previous one.
	metricSessionDuration = promauto.NewHistogram(prometheus.HistogramOpts{
		Name:    "sfworker_session_duration_seconds",
		Help:    "Session duration in seconds.",
		Buckets: prometheus.ExponentialBuckets(1, 2, 12),
	})
)

// ServerConfig holds the sf-worker SSH server configuration.
type ServerConfig struct {
	// Addr is the TCP address handed to net.Listen.
	Addr string
	// HostKeyPath is passed to loadOrGenerateHostKey to obtain the host key.
	// NOTE(review): the helper is defined elsewhere; its name suggests the key
	// is created when the file is absent — confirm.
	HostKeyPath string
	// AuthorizedKeysPath names an OpenSSH authorized_keys file. A missing file
	// is treated as an empty key set.
	AuthorizedKeysPath string
	// SFBin is the path to the sf binary. If empty, resolved from $PATH.
	SFBin string
	// MaxSessions is the limit checked by the accept loop before a connection
	// is served. NewServer replaces values <= 0 with 16.
	MaxSessions int
	// Logger receives the server's log output.
	Logger *log.Logger
}

// Server is the sf-worker SSH server.
type Server struct {
	cfg       ServerConfig      // configuration as stored by NewServer (defaults applied)
	sshConfig *ssh.ServerConfig // host key and public-key authentication callback
	sfBin     string            // resolved path of the sf binary
	// activeSess is the counter the accept loop compares against
	// cfg.MaxSessions.
	activeSess atomic.Int32
	logger     *log.Logger // taken from cfg.Logger
}

// NewServer constructs and configures the SSH server.
func NewServer(cfg ServerConfig) (*Server, error) { if cfg.MaxSessions <= 0 { cfg.MaxSessions = 16 } sfBin := cfg.SFBin if sfBin == "" { var err error sfBin, err = exec.LookPath("sf") if err != nil { return nil, fmt.Errorf("sf binary not found in PATH (set --sf-bin): %w", err) } } hostKey, err := loadOrGenerateHostKey(cfg.HostKeyPath) if err != nil { return nil, fmt.Errorf("host key: %w", err) } authorizedKeys, err := loadAuthorizedKeys(cfg.AuthorizedKeysPath) if err != nil { return nil, fmt.Errorf("authorized keys: %w", err) } sshConfig := &ssh.ServerConfig{ PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { fp := ssh.FingerprintSHA256(key) if _, ok := authorizedKeys[fp]; ok { return &ssh.Permissions{Extensions: map[string]string{"fp": fp}}, nil } return nil, fmt.Errorf("key not authorized: %s", fp) }, } sshConfig.AddHostKey(hostKey) return &Server{ cfg: cfg, sshConfig: sshConfig, sfBin: sfBin, logger: cfg.Logger, }, nil } // ListenAndServe starts accepting SSH connections until ctx is cancelled. 
func (s *Server) ListenAndServe(ctx context.Context) error { ln, err := net.Listen("tcp", s.cfg.Addr) if err != nil { return fmt.Errorf("listen %s: %w", s.cfg.Addr, err) } s.logger.Info("SSH server listening", "addr", s.cfg.Addr, "sf-bin", s.sfBin) var wg sync.WaitGroup go func() { <-ctx.Done() _ = ln.Close() }() for { conn, err := ln.Accept() if err != nil { if ctx.Err() != nil { break } s.logger.Warn("accept error", "err", err) continue } if int(s.activeSess.Load()) >= s.cfg.MaxSessions { s.logger.Warn("max sessions reached, rejecting connection", "remote", conn.RemoteAddr()) _ = conn.Close() continue } wg.Add(1) go func(c net.Conn) { defer wg.Done() s.handleConn(ctx, c) }(conn) } wg.Wait() return nil } func (s *Server) handleConn(ctx context.Context, conn net.Conn) { defer conn.Close() start := time.Now() remote := conn.RemoteAddr().String() sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig) if err != nil { s.logger.Warn("SSH handshake failed", "remote", remote, "err", err) metricSessionsTotal.WithLabelValues("auth_failed").Inc() return } defer sshConn.Close() fp := sshConn.Permissions.Extensions["fp"] s.logger.Info("new connection", "remote", remote, "fp", fp) s.activeSess.Add(1) metricActiveSessions.Inc() defer func() { s.activeSess.Add(-1) metricActiveSessions.Dec() metricSessionDuration.Observe(time.Since(start).Seconds()) }() // Discard global requests. 
go ssh.DiscardRequests(reqs) for newChan := range chans { if newChan.ChannelType() != "session" { _ = newChan.Reject(ssh.UnknownChannelType, "only session channels accepted") continue } ch, requests, err := newChan.Accept() if err != nil { s.logger.Warn("channel accept error", "err", err) metricSessionsTotal.WithLabelValues("error").Inc() return } outcome := s.runSession(ctx, ch, requests, remote, fp) metricSessionsTotal.WithLabelValues(outcome).Inc() s.logger.Info("session ended", "remote", remote, "outcome", outcome, "duration", time.Since(start).Round(time.Millisecond)) } } // loadAuthorizedKeys parses an OpenSSH authorized_keys file into a fingerprint→key map. func loadAuthorizedKeys(path string) (map[string]ssh.PublicKey, error) { f, err := os.Open(path) if os.IsNotExist(err) { return map[string]ssh.PublicKey{}, nil // empty = no keys authorized (warn at connection time) } if err != nil { return nil, err } defer f.Close() keys := map[string]ssh.PublicKey{} scanner := bufio.NewScanner(f) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if line == "" || strings.HasPrefix(line, "#") { continue } pub, _, _, _, err := ssh.ParseAuthorizedKey([]byte(line)) if err != nil { continue // skip malformed lines } keys[ssh.FingerprintSHA256(pub)] = pub } return keys, scanner.Err() }