// session.go — per-SSH-session PTY execution of sf headless.
//
// Purpose: spawn `sf headless <args>` (or another allow-listed sf subcommand)
// in a real PTY for each authorized SSH session, wire the PTY I/O to the SSH
// channel, handle window-resize requests, and clean up on disconnect.
package main
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"fmt"
|
||
|
|
"io"
|
||
|
|
"os"
|
||
|
|
"os/exec"
|
||
|
|
"strings"
|
||
|
|
"sync"
|
||
|
|
|
||
|
|
"github.com/creack/pty"
|
||
|
|
"golang.org/x/crypto/ssh"
|
||
|
|
)
|
||
|
|
|
||
|
|
// allowedSubcommands is the allow-list of words a worker session may pass to
// sf as its subcommand: headless execution plus the two version queries.
// Anything else is refused, so the orchestrator key cannot be used for
// arbitrary execution.
var allowedSubcommands = map[string]bool{
	"--version": true,
	"headless":  true,
	"version":   true,
}
|
||
|
|
|
||
|
|
// runSession handles a single SSH session channel: waits for an exec request,
|
||
|
|
// validates the command, spawns it in a PTY, and streams I/O until completion.
|
||
|
|
// Returns an outcome label for metrics: "ok", "rejected", "error".
|
||
|
|
func (s *Server) runSession(ctx context.Context, ch ssh.Channel, requests <-chan *ssh.Request, remote, fp string) string {
|
||
|
|
defer ch.Close()
|
||
|
|
|
||
|
|
type execReq struct {
|
||
|
|
cmd string
|
||
|
|
ptyW uint32
|
||
|
|
ptyH uint32
|
||
|
|
hasPTY bool
|
||
|
|
}
|
||
|
|
|
||
|
|
var pending execReq
|
||
|
|
|
||
|
|
// Collect session requests until we get "exec" or the channel closes.
|
||
|
|
for req := range requests {
|
||
|
|
switch req.Type {
|
||
|
|
case "pty-req":
|
||
|
|
// https://datatracker.ietf.org/doc/html/rfc4254#section-6.2
|
||
|
|
if len(req.Payload) < 4 {
|
||
|
|
_ = req.Reply(false, nil)
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
// term string length prefix
|
||
|
|
termLen := int(req.Payload[0])<<24 | int(req.Payload[1])<<16 | int(req.Payload[2])<<8 | int(req.Payload[3])
|
||
|
|
offset := 4 + termLen
|
||
|
|
if len(req.Payload) < offset+8 {
|
||
|
|
_ = req.Reply(false, nil)
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
pending.ptyW = uint32(req.Payload[offset])<<24 | uint32(req.Payload[offset+1])<<16 | uint32(req.Payload[offset+2])<<8 | uint32(req.Payload[offset+3])
|
||
|
|
pending.ptyH = uint32(req.Payload[offset+4])<<24 | uint32(req.Payload[offset+5])<<16 | uint32(req.Payload[offset+6])<<8 | uint32(req.Payload[offset+7])
|
||
|
|
pending.hasPTY = true
|
||
|
|
_ = req.Reply(true, nil)
|
||
|
|
|
||
|
|
case "window-change":
|
||
|
|
// Handled after PTY is started (ignore if no PTY yet).
|
||
|
|
_ = req.Reply(false, nil)
|
||
|
|
|
||
|
|
case "exec":
|
||
|
|
if len(req.Payload) < 4 {
|
||
|
|
_ = req.Reply(false, nil)
|
||
|
|
return "rejected"
|
||
|
|
}
|
||
|
|
cmdLen := int(req.Payload[0])<<24 | int(req.Payload[1])<<16 | int(req.Payload[2])<<8 | int(req.Payload[3])
|
||
|
|
if len(req.Payload) < 4+cmdLen {
|
||
|
|
_ = req.Reply(false, nil)
|
||
|
|
return "rejected"
|
||
|
|
}
|
||
|
|
pending.cmd = string(req.Payload[4 : 4+cmdLen])
|
||
|
|
_ = req.Reply(true, nil)
|
||
|
|
|
||
|
|
outcome := s.execCommand(ctx, ch, requests, pending.cmd, pending.ptyW, pending.ptyH, remote, fp)
|
||
|
|
return outcome
|
||
|
|
|
||
|
|
case "shell":
|
||
|
|
// No interactive shell — reject.
|
||
|
|
_ = req.Reply(false, nil)
|
||
|
|
sendExitStatus(ch, 1)
|
||
|
|
return "rejected"
|
||
|
|
|
||
|
|
default:
|
||
|
|
if req.WantReply {
|
||
|
|
_ = req.Reply(false, nil)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return "ok"
|
||
|
|
}
|
||
|
|
|
||
|
|
// execCommand validates and executes the requested command in a PTY.
|
||
|
|
func (s *Server) execCommand(ctx context.Context, ch ssh.Channel, requests <-chan *ssh.Request, cmdStr string, ptyW, ptyH uint32, remote, fp string) string {
|
||
|
|
args := strings.Fields(cmdStr)
|
||
|
|
if len(args) == 0 {
|
||
|
|
sendExitStatus(ch, 1)
|
||
|
|
return "rejected"
|
||
|
|
}
|
||
|
|
|
||
|
|
// Validate: first arg must be "sf" (or the sf-bin basename), second must be an allowed subcommand.
|
||
|
|
sfBase := s.sfBin
|
||
|
|
for i := len(sfBase) - 1; i >= 0; i-- {
|
||
|
|
if sfBase[i] == '/' || sfBase[i] == '\\' {
|
||
|
|
sfBase = sfBase[i+1:]
|
||
|
|
break
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
start := 0
|
||
|
|
if args[0] == sfBase || args[0] == "sf" {
|
||
|
|
start = 1
|
||
|
|
}
|
||
|
|
if start >= len(args) || !allowedSubcommands[args[start]] {
|
||
|
|
s.logger.Warn("rejected command", "remote", remote, "cmd", cmdStr)
|
||
|
|
fmt.Fprintf(ch, "sf-worker: command not allowed: %q\r\n", cmdStr)
|
||
|
|
sendExitStatus(ch, 1)
|
||
|
|
return "rejected"
|
||
|
|
}
|
||
|
|
|
||
|
|
// Build the actual command: replace the leading "sf" with the real binary path.
|
||
|
|
execArgs := append([]string{s.sfBin}, args[start:]...)
|
||
|
|
cmd := exec.CommandContext(ctx, execArgs[0], execArgs[1:]...)
|
||
|
|
cmd.Env = append(os.Environ(), "SF_WORKER=1", fmt.Sprintf("SF_WORKER_CLIENT_FP=%s", fp))
|
||
|
|
|
||
|
|
s.logger.Info("exec", "remote", remote, "cmd", execArgs)
|
||
|
|
|
||
|
|
// Start with PTY.
|
||
|
|
ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{
|
||
|
|
Cols: uint16(ptyW),
|
||
|
|
Rows: uint16(ptyH),
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
s.logger.Error("pty start failed", "err", err)
|
||
|
|
fmt.Fprintf(ch, "sf-worker: failed to start: %v\r\n", err)
|
||
|
|
sendExitStatus(ch, 1)
|
||
|
|
return "error"
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
_ = ptmx.Close()
|
||
|
|
}()
|
||
|
|
|
||
|
|
// Handle subsequent window-change requests in background.
|
||
|
|
go func() {
|
||
|
|
for req := range requests {
|
||
|
|
if req.Type == "window-change" && len(req.Payload) >= 8 {
|
||
|
|
w := uint32(req.Payload[0])<<24 | uint32(req.Payload[1])<<16 | uint32(req.Payload[2])<<8 | uint32(req.Payload[3])
|
||
|
|
h := uint32(req.Payload[4])<<24 | uint32(req.Payload[5])<<16 | uint32(req.Payload[6])<<8 | uint32(req.Payload[7])
|
||
|
|
_ = pty.Setsize(ptmx, &pty.Winsize{Cols: uint16(w), Rows: uint16(h)})
|
||
|
|
}
|
||
|
|
if req.WantReply {
|
||
|
|
_ = req.Reply(false, nil)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}()
|
||
|
|
|
||
|
|
// Bidirectional copy: PTY ↔ SSH channel.
|
||
|
|
var wg sync.WaitGroup
|
||
|
|
wg.Add(2)
|
||
|
|
go func() {
|
||
|
|
defer wg.Done()
|
||
|
|
_, _ = io.Copy(ptmx, ch)
|
||
|
|
}()
|
||
|
|
go func() {
|
||
|
|
defer wg.Done()
|
||
|
|
_, _ = io.Copy(ch, ptmx)
|
||
|
|
}()
|
||
|
|
|
||
|
|
err = cmd.Wait()
|
||
|
|
wg.Wait()
|
||
|
|
|
||
|
|
exitCode := 0
|
||
|
|
if err != nil {
|
||
|
|
if exitErr, ok := err.(*exec.ExitError); ok {
|
||
|
|
exitCode = exitErr.ExitCode()
|
||
|
|
} else {
|
||
|
|
exitCode = 1
|
||
|
|
}
|
||
|
|
}
|
||
|
|
sendExitStatus(ch, uint32(exitCode))
|
||
|
|
|
||
|
|
if exitCode != 0 {
|
||
|
|
return "error"
|
||
|
|
}
|
||
|
|
return "ok"
|
||
|
|
}
|
||
|
|
|
||
|
|
// sendExitStatus sends an SSH exit-status request to the channel.
|
||
|
|
func sendExitStatus(ch ssh.Channel, code uint32) {
|
||
|
|
payload := []byte{byte(code >> 24), byte(code >> 16), byte(code >> 8), byte(code)}
|
||
|
|
_, _ = ch.SendRequest("exit-status", false, payload)
|
||
|
|
}
|