singularity-forge/sf-worker/server.go

216 lines
5.5 KiB
Go
Raw Normal View History

// server.go — SSH server setup and connection acceptance.
//
// Purpose: accept SSH connections from the SF orchestrator, enforce key-based auth,
// and hand each session off to runSession for PTY execution.
package main
import (
"bufio"
"context"
"fmt"
"net"
"os"
"os/exec"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/charmbracelet/log"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"golang.org/x/crypto/ssh"
)
// Prometheus collectors for SSH session activity, created through promauto.
var (
	// metricActiveSessions is a gauge that handleConn increments once a
	// connection has authenticated and decrements when that connection ends.
	metricActiveSessions = promauto.NewGauge(
		prometheus.GaugeOpts{
			Name: "sfworker_active_sessions",
			Help: "Currently active SSH sessions.",
		},
	)

	// metricSessionsTotal counts finished sessions, partitioned by the
	// "outcome" label.
	metricSessionsTotal = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Name: "sfworker_sessions_total",
			Help: "Total SSH sessions by outcome.",
		},
		[]string{"outcome"},
	)

	// metricSessionDuration records how long each connection lasted, using 12
	// exponential buckets that start at 1 and double each step.
	metricSessionDuration = promauto.NewHistogram(
		prometheus.HistogramOpts{
			Name:    "sfworker_session_duration_seconds",
			Help:    "Session duration in seconds.",
			Buckets: prometheus.ExponentialBuckets(1, 2, 12),
		},
	)
)
// ServerConfig carries the settings NewServer needs to build an sf-worker SSH
// server.
type ServerConfig struct {
	// Addr is the TCP address the server listens on.
	Addr string
	// HostKeyPath is the location handed to loadOrGenerateHostKey to obtain
	// the server's SSH host key.
	HostKeyPath string
	// AuthorizedKeysPath is the OpenSSH authorized_keys file listing the
	// public keys that are allowed to connect.
	AuthorizedKeysPath string
	// SFBin is the path to the sf binary. If empty, resolved from $PATH.
	SFBin string
	// MaxSessions caps concurrently active sessions. NewServer substitutes 16
	// for any value that is zero or negative.
	MaxSessions int
	// Logger receives the server's log output.
	Logger *log.Logger
}
// Server is the sf-worker SSH server. Build one with NewServer and run it with
// ListenAndServe.
type Server struct {
	// cfg is the configuration the server was built from.
	cfg ServerConfig
	// sshConfig holds the host key and the public-key authentication callback.
	sshConfig *ssh.ServerConfig
	// sfBin is the resolved path of the sf binary.
	sfBin string
	// activeSess counts connections that have completed the SSH handshake and
	// are still being served.
	activeSess atomic.Int32
	// logger is the logger taken from cfg.Logger.
	logger *log.Logger
}
// NewServer constructs and configures the SSH server.
//
// Defaults applied to cfg: a non-positive MaxSessions becomes 16, and a nil
// Logger becomes log.Default() so that later logging calls never dereference a
// nil logger. The sf binary is taken from cfg.SFBin or, when that is empty,
// looked up as "sf" in $PATH.
//
// Authentication is public-key only: a client is accepted when the SHA-256
// fingerprint of the key it offers is present in the authorized-keys set
// loaded from cfg.AuthorizedKeysPath. The matching fingerprint is exposed to
// the connection handler through the "fp" permission extension.
//
// An error is returned when the sf binary cannot be found, or when the host
// key or the authorized keys cannot be loaded.
func NewServer(cfg ServerConfig) (*Server, error) {
	if cfg.MaxSessions <= 0 {
		cfg.MaxSessions = 16
	}
	if cfg.Logger == nil {
		cfg.Logger = log.Default()
	}

	sfBin := cfg.SFBin
	if sfBin == "" {
		resolved, err := exec.LookPath("sf")
		if err != nil {
			return nil, fmt.Errorf("sf binary not found in PATH (set --sf-bin): %w", err)
		}
		sfBin = resolved
	}

	hostKey, err := loadOrGenerateHostKey(cfg.HostKeyPath)
	if err != nil {
		return nil, fmt.Errorf("host key: %w", err)
	}

	authorizedKeys, err := loadAuthorizedKeys(cfg.AuthorizedKeysPath)
	if err != nil {
		return nil, fmt.Errorf("authorized keys: %w", err)
	}
	if len(authorizedKeys) == 0 {
		// Not fatal, but without this the operator only finds out when every
		// client fails authentication.
		cfg.Logger.Warn("no authorized keys loaded, all connections will be rejected",
			"path", cfg.AuthorizedKeysPath)
	}

	sshConfig := &ssh.ServerConfig{
		PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
			fp := ssh.FingerprintSHA256(key)
			if _, ok := authorizedKeys[fp]; !ok {
				return nil, fmt.Errorf("key not authorized: %s", fp)
			}
			return &ssh.Permissions{Extensions: map[string]string{"fp": fp}}, nil
		},
	}
	sshConfig.AddHostKey(hostKey)

	return &Server{
		cfg:       cfg,
		sshConfig: sshConfig,
		sfBin:     sfBin,
		logger:    cfg.Logger,
	}, nil
}
// ListenAndServe starts accepting SSH connections until ctx is cancelled.
//
// It listens on s.cfg.Addr and serves every accepted connection on its own
// goroutine via handleConn. Connections that arrive while the active-session
// count is already at s.cfg.MaxSessions are closed immediately.
//
// When ctx is cancelled the listener is closed, the accept loop stops, and the
// call blocks until all in-flight connection handlers have returned. The
// result is nil in that case; a non-nil error is returned only when the
// listener cannot be created.
//
// Accept errors that are not caused by cancellation are logged and retried
// with a capped exponential backoff, so a persistent failure (for example,
// running out of file descriptors) does not turn into a busy loop.
func (s *Server) ListenAndServe(ctx context.Context) error {
	ln, err := net.Listen("tcp", s.cfg.Addr)
	if err != nil {
		return fmt.Errorf("listen %s: %w", s.cfg.Addr, err)
	}
	s.logger.Info("SSH server listening", "addr", s.cfg.Addr, "sf-bin", s.sfBin)

	// Closing the listener is what unblocks Accept once ctx is cancelled.
	go func() {
		<-ctx.Done()
		_ = ln.Close()
	}()

	const (
		minBackoff = 5 * time.Millisecond
		maxBackoff = time.Second
	)
	var (
		wg      sync.WaitGroup
		backoff time.Duration
	)
	for {
		conn, err := ln.Accept()
		if err != nil {
			if ctx.Err() != nil {
				break
			}
			s.logger.Warn("accept error", "err", err)

			// Double the delay on each consecutive failure, up to maxBackoff.
			if backoff == 0 {
				backoff = minBackoff
			} else if backoff *= 2; backoff > maxBackoff {
				backoff = maxBackoff
			}
			timer := time.NewTimer(backoff)
			select {
			case <-timer.C:
			case <-ctx.Done():
				// The next Accept fails on the closed listener and the
				// ctx.Err() check above ends the loop.
				timer.Stop()
			}
			continue
		}
		backoff = 0

		if int(s.activeSess.Load()) >= s.cfg.MaxSessions {
			s.logger.Warn("max sessions reached, rejecting connection", "remote", conn.RemoteAddr())
			_ = conn.Close()
			continue
		}

		wg.Add(1)
		go func(c net.Conn) {
			defer wg.Done()
			s.handleConn(ctx, c)
		}(conn)
	}

	wg.Wait()
	return nil
}
// handleConn performs the SSH handshake on conn and then serves its "session"
// channels one at a time through runSession until the peer disconnects or ctx
// is cancelled. conn is always closed before the function returns.
//
// Safeguards:
//   - The handshake runs under a deadline, so a peer that connects and then
//     stalls cannot hold the goroutine and socket indefinitely.
//   - The session limit is enforced atomically when the active-session counter
//     is incremented. The check in the accept loop happens before the
//     handshake, so on its own it lets several connections authenticate
//     concurrently and exceed s.cfg.MaxSessions.
//   - While waiting for the next channel the loop also watches ctx, so an idle
//     connection does not keep the server from shutting down.
//
// Handshake failures are counted under the "auth_failed" outcome, deadline and
// channel-accept failures under "error"; otherwise the outcome is whatever
// runSession reports.
func (s *Server) handleConn(ctx context.Context, conn net.Conn) {
	// Upper bound on how long a peer may take to complete the SSH handshake.
	const handshakeTimeout = 30 * time.Second

	defer conn.Close()
	start := time.Now()
	remote := conn.RemoteAddr().String()

	if err := conn.SetDeadline(time.Now().Add(handshakeTimeout)); err != nil {
		s.logger.Warn("set handshake deadline failed", "remote", remote, "err", err)
		metricSessionsTotal.WithLabelValues("error").Inc()
		return
	}
	sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig)
	if err != nil {
		s.logger.Warn("SSH handshake failed", "remote", remote, "err", err)
		metricSessionsTotal.WithLabelValues("auth_failed").Inc()
		return
	}
	defer sshConn.Close()

	// The deadline only guards the handshake; sessions may run much longer.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		s.logger.Warn("clear handshake deadline failed", "remote", remote, "err", err)
		metricSessionsTotal.WithLabelValues("error").Inc()
		return
	}

	var fp string
	if sshConn.Permissions != nil {
		fp = sshConn.Permissions.Extensions["fp"]
	}

	// Reserve a slot first and give it back if that pushed us over the limit.
	if int(s.activeSess.Add(1)) > s.cfg.MaxSessions {
		s.activeSess.Add(-1)
		s.logger.Warn("max sessions reached, rejecting connection", "remote", remote)
		return
	}
	metricActiveSessions.Inc()
	defer func() {
		s.activeSess.Add(-1)
		metricActiveSessions.Dec()
		metricSessionDuration.Observe(time.Since(start).Seconds())
	}()
	s.logger.Info("new connection", "remote", remote, "fp", fp)

	// Discard global requests.
	go ssh.DiscardRequests(reqs)

	for {
		var newChan ssh.NewChannel
		select {
		case <-ctx.Done():
			return
		case c, ok := <-chans:
			if !ok {
				// Channel stream closed: the peer has disconnected.
				return
			}
			newChan = c
		}

		if newChan.ChannelType() != "session" {
			_ = newChan.Reject(ssh.UnknownChannelType, "only session channels accepted")
			continue
		}
		ch, requests, err := newChan.Accept()
		if err != nil {
			s.logger.Warn("channel accept error", "err", err)
			metricSessionsTotal.WithLabelValues("error").Inc()
			return
		}
		outcome := s.runSession(ctx, ch, requests, remote, fp)
		metricSessionsTotal.WithLabelValues(outcome).Inc()
		s.logger.Info("session ended", "remote", remote, "outcome", outcome, "duration", time.Since(start).Round(time.Millisecond))
	}
}
// loadAuthorizedKeys reads the OpenSSH authorized_keys file at path and
// returns its public keys indexed by SHA-256 fingerprint.
//
// A missing file is not an error: the result is an empty map, which means no
// key is authorized. Blank lines, lines beginning with "#", and lines that do
// not parse as an authorized key are skipped. Any other failure to open or
// read the file is returned as the error.
func loadAuthorizedKeys(path string) (map[string]ssh.PublicKey, error) {
	file, err := os.Open(path)
	if err != nil {
		if os.IsNotExist(err) {
			return map[string]ssh.PublicKey{}, nil
		}
		return nil, err
	}
	defer file.Close()

	authorized := make(map[string]ssh.PublicKey)
	lines := bufio.NewScanner(file)
	for lines.Scan() {
		entry := strings.TrimSpace(lines.Text())
		if len(entry) == 0 || entry[0] == '#' {
			continue
		}
		key, _, _, _, parseErr := ssh.ParseAuthorizedKey([]byte(entry))
		if parseErr == nil {
			authorized[ssh.FingerprintSHA256(key)] = key
		}
		// Entries that fail to parse are dropped without comment.
	}
	return authorized, lines.Err()
}