// session.go — per-SSH-session PTY execution of sf headless. // // Purpose: spawn `sf headless ` in a real PTY for each authorized SSH session, // wire the PTY I/O to the SSH channel, handle window-resize requests, and clean up // on disconnect. package main import ( "context" "fmt" "io" "os" "os/exec" "strings" "sync" "github.com/creack/pty" "golang.org/x/crypto/ssh" ) // allowedSubcommands are the sf subcommands a worker session may execute. // Restricting to headless prevents the orchestrator key from being used for arbitrary execution. var allowedSubcommands = map[string]bool{ "headless": true, "version": true, "--version": true, } // runSession handles a single SSH session channel: waits for an exec request, // validates the command, spawns it in a PTY, and streams I/O until completion. // Returns an outcome label for metrics: "ok", "rejected", "error". func (s *Server) runSession(ctx context.Context, ch ssh.Channel, requests <-chan *ssh.Request, remote, fp string) string { defer ch.Close() type execReq struct { cmd string ptyW uint32 ptyH uint32 hasPTY bool } var pending execReq // Collect session requests until we get "exec" or the channel closes. for req := range requests { switch req.Type { case "pty-req": // https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 if len(req.Payload) < 4 { _ = req.Reply(false, nil) continue } // term string length prefix termLen := int(req.Payload[0])<<24 | int(req.Payload[1])<<16 | int(req.Payload[2])<<8 | int(req.Payload[3]) offset := 4 + termLen if len(req.Payload) < offset+8 { _ = req.Reply(false, nil) continue } pending.ptyW = uint32(req.Payload[offset])<<24 | uint32(req.Payload[offset+1])<<16 | uint32(req.Payload[offset+2])<<8 | uint32(req.Payload[offset+3]) pending.ptyH = uint32(req.Payload[offset+4])<<24 | uint32(req.Payload[offset+5])<<16 | uint32(req.Payload[offset+6])<<8 | uint32(req.Payload[offset+7]) pending.hasPTY = true _ = req.Reply(true, nil) case "window-change": // Handled after PTY is started (ignore if no PTY yet). _ = req.Reply(false, nil) case "exec": if len(req.Payload) < 4 { _ = req.Reply(false, nil) return "rejected" } cmdLen := int(req.Payload[0])<<24 | int(req.Payload[1])<<16 | int(req.Payload[2])<<8 | int(req.Payload[3]) if len(req.Payload) < 4+cmdLen { _ = req.Reply(false, nil) return "rejected" } pending.cmd = string(req.Payload[4 : 4+cmdLen]) _ = req.Reply(true, nil) outcome := s.execCommand(ctx, ch, requests, pending.cmd, pending.ptyW, pending.ptyH, remote, fp) return outcome case "shell": // No interactive shell — reject. _ = req.Reply(false, nil) sendExitStatus(ch, 1) return "rejected" default: if req.WantReply { _ = req.Reply(false, nil) } } } return "ok" } // execCommand validates and executes the requested command in a PTY. func (s *Server) execCommand(ctx context.Context, ch ssh.Channel, requests <-chan *ssh.Request, cmdStr string, ptyW, ptyH uint32, remote, fp string) string { args := strings.Fields(cmdStr) if len(args) == 0 { sendExitStatus(ch, 1) return "rejected" } // Validate: first arg must be "sf" (or the sf-bin basename), second must be an allowed subcommand. sfBase := s.sfBin for i := len(sfBase) - 1; i >= 0; i-- { if sfBase[i] == '/' || sfBase[i] == '\\' { sfBase = sfBase[i+1:] break } } start := 0 if args[0] == sfBase || args[0] == "sf" { start = 1 } if start >= len(args) || !allowedSubcommands[args[start]] { s.logger.Warn("rejected command", "remote", remote, "cmd", cmdStr) fmt.Fprintf(ch, "sf-worker: command not allowed: %q\r\n", cmdStr) sendExitStatus(ch, 1) return "rejected" } // Build the actual command: replace the leading "sf" with the real binary path. execArgs := append([]string{s.sfBin}, args[start:]...) cmd := exec.CommandContext(ctx, execArgs[0], execArgs[1:]...) cmd.Env = append(os.Environ(), "SF_WORKER=1", fmt.Sprintf("SF_WORKER_CLIENT_FP=%s", fp)) s.logger.Info("exec", "remote", remote, "cmd", execArgs) // Start with PTY. ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{ Cols: uint16(ptyW), Rows: uint16(ptyH), }) if err != nil { s.logger.Error("pty start failed", "err", err) fmt.Fprintf(ch, "sf-worker: failed to start: %v\r\n", err) sendExitStatus(ch, 1) return "error" } defer func() { _ = ptmx.Close() }() // Handle subsequent window-change requests in background. go func() { for req := range requests { if req.Type == "window-change" && len(req.Payload) >= 8 { w := uint32(req.Payload[0])<<24 | uint32(req.Payload[1])<<16 | uint32(req.Payload[2])<<8 | uint32(req.Payload[3]) h := uint32(req.Payload[4])<<24 | uint32(req.Payload[5])<<16 | uint32(req.Payload[6])<<8 | uint32(req.Payload[7]) _ = pty.Setsize(ptmx, &pty.Winsize{Cols: uint16(w), Rows: uint16(h)}) } if req.WantReply { _ = req.Reply(false, nil) } } }() // Bidirectional copy: PTY ↔ SSH channel. var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() _, _ = io.Copy(ptmx, ch) }() go func() { defer wg.Done() _, _ = io.Copy(ch, ptmx) }() err = cmd.Wait() wg.Wait() exitCode := 0 if err != nil { if exitErr, ok := err.(*exec.ExitError); ok { exitCode = exitErr.ExitCode() } else { exitCode = 1 } } sendExitStatus(ch, uint32(exitCode)) if exitCode != 0 { return "error" } return "ok" } // sendExitStatus sends an SSH exit-status request to the channel. func sendExitStatus(ch ssh.Channel, code uint32) { payload := []byte{byte(code >> 24), byte(code >> 16), byte(code >> 8), byte(code)} _, _ = ch.SendRequest("exit-status", false, payload) }