- Delete ghost package packages/pi-agent-core (no dist, no consumers, TS build errors; its JS source sf-db.js had 3 commits not mirrored in TS)
- Remove build:pi-agent-core from the root package.json build:pi pipeline
- Merge all models from MODEL_COST_PER_1K_INPUT into BUNDLED_COST_TABLE (model-cost-table.js is now the single canonical cost source)
- Remove the duplicate MODEL_COST_PER_1K_INPUT object and getModelCost() from model-router.js; use lookupModelCost() from model-cost-table.js instead
- Replace the hand-rolled isTransientNetworkError in preferences-models.js with delegation to classifyError() in error-classifier.js

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
215 lines
5.5 KiB
Go
215 lines
5.5 KiB
Go
// server.go — SSH server setup and connection acceptance.
|
|
//
|
|
// Purpose: accept SSH connections from the SF orchestrator, enforce key-based auth,
|
|
// and hand each session off to runSession for PTY execution.
|
|
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"os/exec"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/charmbracelet/log"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
var (
|
|
metricActiveSessions = promauto.NewGauge(prometheus.GaugeOpts{
|
|
Name: "sfworker_active_sessions",
|
|
Help: "Currently active SSH sessions.",
|
|
})
|
|
metricSessionsTotal = promauto.NewCounterVec(prometheus.CounterOpts{
|
|
Name: "sfworker_sessions_total",
|
|
Help: "Total SSH sessions by outcome.",
|
|
}, []string{"outcome"})
|
|
metricSessionDuration = promauto.NewHistogram(prometheus.HistogramOpts{
|
|
Name: "sfworker_session_duration_seconds",
|
|
Help: "Session duration in seconds.",
|
|
Buckets: prometheus.ExponentialBuckets(1, 2, 12),
|
|
})
|
|
)
|
|
|
|
// ServerConfig holds the sf-worker SSH server configuration.
// It is passed by value to NewServer, which fills in defaults on its own copy.
type ServerConfig struct {
	// Addr is the TCP listen address handed to net.Listen (e.g. ":2222").
	Addr string
	// HostKeyPath is where the server's host key is loaded from, or generated
	// into if it does not exist (see loadOrGenerateHostKey).
	HostKeyPath string
	// AuthorizedKeysPath points at an OpenSSH authorized_keys file. A missing
	// file is treated as "no keys authorized", so every client is rejected.
	AuthorizedKeysPath string
	// SFBin is the path to the sf binary. If empty, resolved from $PATH.
	SFBin string
	// MaxSessions caps concurrent connections; values <= 0 are replaced with
	// 16 by NewServer.
	MaxSessions int
	// Logger receives the server's log output.
	Logger *log.Logger
}
|
|
|
|
// Server is the sf-worker SSH server. Build one with NewServer and run it with
// ListenAndServe.
type Server struct {
	// cfg is the configuration NewServer received, with defaults such as
	// MaxSessions already filled in.
	cfg ServerConfig
	// sshConfig carries the host key and the public-key auth callback.
	sshConfig *ssh.ServerConfig
	// sfBin is the resolved sf binary path (cfg.SFBin or a $PATH lookup).
	sfBin string
	// activeSess counts live connections; ListenAndServe refuses new ones once
	// it reaches cfg.MaxSessions.
	activeSess atomic.Int32
	// logger is the logger taken from cfg.Logger.
	logger *log.Logger
}
|
|
|
|
// NewServer constructs and configures the SSH server.
|
|
func NewServer(cfg ServerConfig) (*Server, error) {
|
|
if cfg.MaxSessions <= 0 {
|
|
cfg.MaxSessions = 16
|
|
}
|
|
|
|
sfBin := cfg.SFBin
|
|
if sfBin == "" {
|
|
var err error
|
|
sfBin, err = exec.LookPath("sf")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("sf binary not found in PATH (set --sf-bin): %w", err)
|
|
}
|
|
}
|
|
|
|
hostKey, err := loadOrGenerateHostKey(cfg.HostKeyPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("host key: %w", err)
|
|
}
|
|
|
|
authorizedKeys, err := loadAuthorizedKeys(cfg.AuthorizedKeysPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("authorized keys: %w", err)
|
|
}
|
|
|
|
sshConfig := &ssh.ServerConfig{
|
|
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
|
fp := ssh.FingerprintSHA256(key)
|
|
if _, ok := authorizedKeys[fp]; ok {
|
|
return &ssh.Permissions{Extensions: map[string]string{"fp": fp}}, nil
|
|
}
|
|
return nil, fmt.Errorf("key not authorized: %s", fp)
|
|
},
|
|
}
|
|
sshConfig.AddHostKey(hostKey)
|
|
|
|
return &Server{
|
|
cfg: cfg,
|
|
sshConfig: sshConfig,
|
|
sfBin: sfBin,
|
|
logger: cfg.Logger,
|
|
}, nil
|
|
}
|
|
|
|
// ListenAndServe starts accepting SSH connections until ctx is cancelled.
|
|
func (s *Server) ListenAndServe(ctx context.Context) error {
|
|
ln, err := net.Listen("tcp", s.cfg.Addr)
|
|
if err != nil {
|
|
return fmt.Errorf("listen %s: %w", s.cfg.Addr, err)
|
|
}
|
|
s.logger.Info("SSH server listening", "addr", s.cfg.Addr, "sf-bin", s.sfBin)
|
|
|
|
var wg sync.WaitGroup
|
|
go func() {
|
|
<-ctx.Done()
|
|
_ = ln.Close()
|
|
}()
|
|
|
|
for {
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
break
|
|
}
|
|
s.logger.Warn("accept error", "err", err)
|
|
continue
|
|
}
|
|
|
|
if int(s.activeSess.Load()) >= s.cfg.MaxSessions {
|
|
s.logger.Warn("max sessions reached, rejecting connection", "remote", conn.RemoteAddr())
|
|
_ = conn.Close()
|
|
continue
|
|
}
|
|
|
|
wg.Add(1)
|
|
go func(c net.Conn) {
|
|
defer wg.Done()
|
|
s.handleConn(ctx, c)
|
|
}(conn)
|
|
}
|
|
|
|
wg.Wait()
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) handleConn(ctx context.Context, conn net.Conn) {
|
|
defer conn.Close()
|
|
start := time.Now()
|
|
remote := conn.RemoteAddr().String()
|
|
|
|
sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig)
|
|
if err != nil {
|
|
s.logger.Warn("SSH handshake failed", "remote", remote, "err", err)
|
|
metricSessionsTotal.WithLabelValues("auth_failed").Inc()
|
|
return
|
|
}
|
|
defer sshConn.Close()
|
|
|
|
fp := sshConn.Permissions.Extensions["fp"]
|
|
s.logger.Info("new connection", "remote", remote, "fp", fp)
|
|
s.activeSess.Add(1)
|
|
metricActiveSessions.Inc()
|
|
defer func() {
|
|
s.activeSess.Add(-1)
|
|
metricActiveSessions.Dec()
|
|
metricSessionDuration.Observe(time.Since(start).Seconds())
|
|
}()
|
|
|
|
// Discard global requests.
|
|
go ssh.DiscardRequests(reqs)
|
|
|
|
for newChan := range chans {
|
|
if newChan.ChannelType() != "session" {
|
|
_ = newChan.Reject(ssh.UnknownChannelType, "only session channels accepted")
|
|
continue
|
|
}
|
|
ch, requests, err := newChan.Accept()
|
|
if err != nil {
|
|
s.logger.Warn("channel accept error", "err", err)
|
|
metricSessionsTotal.WithLabelValues("error").Inc()
|
|
return
|
|
}
|
|
|
|
outcome := s.runSession(ctx, ch, requests, remote, fp)
|
|
metricSessionsTotal.WithLabelValues(outcome).Inc()
|
|
s.logger.Info("session ended", "remote", remote, "outcome", outcome, "duration", time.Since(start).Round(time.Millisecond))
|
|
}
|
|
}
|
|
|
|
// loadAuthorizedKeys parses an OpenSSH authorized_keys file into a fingerprint→key map.
|
|
func loadAuthorizedKeys(path string) (map[string]ssh.PublicKey, error) {
|
|
f, err := os.Open(path)
|
|
if os.IsNotExist(err) {
|
|
return map[string]ssh.PublicKey{}, nil // empty = no keys authorized (warn at connection time)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer f.Close()
|
|
|
|
keys := map[string]ssh.PublicKey{}
|
|
scanner := bufio.NewScanner(f)
|
|
for scanner.Scan() {
|
|
line := strings.TrimSpace(scanner.Text())
|
|
if line == "" || strings.HasPrefix(line, "#") {
|
|
continue
|
|
}
|
|
pub, _, _, _, err := ssh.ParseAuthorizedKey([]byte(line))
|
|
if err != nil {
|
|
continue // skip malformed lines
|
|
}
|
|
keys[ssh.FingerprintSHA256(pub)] = pub
|
|
}
|
|
return keys, scanner.Err()
|
|
}
|