singularity-forge/rust-engine/crates/engine/src/symbol.rs
Mikael Hugo 5f52680285 chore: snapshot in-flight work (mcp graph refactor, native edit module, misc)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-02 08:31:44 +02:00

451 lines
16 KiB
Rust

//! Symbol-level structural replace and insert.
//!
//! Exposes [`replace_symbol`] and [`insert_around_symbol`] to JavaScript via
//! napi-rs. Both functions use the same ast-grep infrastructure as
//! `forge_ast::ast` (tree-sitter pattern matching) but add a higher-level
//! "find me the declaration named X" abstraction on top.
//!
//! ## Language support (v1)
//!
//! Only **TypeScript / JavaScript / TSX** are fully supported. For all other
//! languages the functions return an `Err` asking the caller to fall back to
//! `astEdit` with a custom pattern.
//!
//! ## Replacement scope (v1)
//!
//! For simplicity `replaceSymbol` replaces the **entire matched declaration**
//! (function / arrow / method node), not just the body. `new_body` is
//! therefore expected to be the full declaration text, e.g.:
//!
//! ```text
//! function foo(x: number): number { return x + 1; }
//! ```
//!
//! `insertAroundSymbol` supports only `BeforeDecl` and `AfterDecl` in v1.
//! `AtBodyStart` / `AtBodyEnd` return `Err("not yet implemented")`.
use ast_grep_core::{matcher::Pattern, tree_sitter::LanguageExt, Language};
use forge_ast::language::SupportLang;
use napi::{Error, Result};
use napi_derive::napi;
use std::{
fs,
path::{Path, PathBuf},
};
// ─── napi types ──────────────────────────────────────────────────────────────
/// Optional settings for [`replace_symbol`] (exported to JS as `replaceSymbol`).
#[napi(object)]
pub struct ReplaceSymbolOptions {
    /// Force a specific language ("typescript", "rust", "python", …).
    /// If absent (or blank after trimming), inferred from the file extension.
    /// Only the TypeScript / JavaScript / TSX family is currently accepted for
    /// symbol resolution; other languages are rejected with an error.
    pub lang: Option<String>,
    /// fsync the write. Defaults to true. When enabled, the temporary file is
    /// synced before it is renamed into place, and the parent directory is
    /// synced (best effort) afterwards.
    pub fsync: Option<bool>,
}
/// Outcome of [`replace_symbol`].
#[napi(object)]
pub struct ReplaceSymbolResult {
    /// `true` when exactly one declaration matched and the file was rewritten;
    /// `false` when nothing matched (the file is left untouched).
    pub matched: bool,
    /// Byte offset of the start of the replaced range (set only when matched).
    /// This is a UTF-8 byte offset into the file contents *before* the edit,
    /// not a JS (UTF-16) string index.
    #[napi(js_name = "byteStart")]
    pub byte_start: Option<u32>,
    /// Byte offset of the end of the replaced range (set only when matched).
    /// Exclusive, UTF-8 bytes, relative to the contents before the edit.
    #[napi(js_name = "byteEnd")]
    pub byte_end: Option<u32>,
    /// 1-based line number of the replacement start (set only when matched).
    #[napi(js_name = "startLine")]
    pub start_line: Option<u32>,
}
/// Where [`insert_around_symbol`] places the new code relative to the matched
/// declaration. Exposed to JS as a napi string enum.
#[napi(string_enum)]
pub enum InsertPosition {
    /// At the first byte of the matched declaration (the code ends up
    /// immediately before it; no separator is added).
    BeforeDecl,
    /// At the byte just past the end of the matched declaration (no separator
    /// is added).
    AfterDecl,
    /// Presumably intended for the start of the declaration body; currently
    /// always rejected with a "not yet implemented" error.
    AtBodyStart,
    /// Presumably intended for the end of the declaration body; currently
    /// always rejected with a "not yet implemented" error.
    AtBodyEnd,
}
/// Optional settings for [`insert_around_symbol`] (exported to JS as
/// `insertAroundSymbol`).
#[napi(object)]
pub struct InsertAroundSymbolOptions {
    /// Force a specific language name or alias; if absent (or blank after
    /// trimming), inferred from the file extension. Only the TypeScript /
    /// JavaScript / TSX family is currently accepted.
    pub lang: Option<String>,
    /// fsync the write. Defaults to true.
    pub fsync: Option<bool>,
}
/// Outcome of [`insert_around_symbol`].
#[napi(object)]
pub struct InsertAroundSymbolResult {
    /// `true` when exactly one declaration matched and the code was written;
    /// `false` when nothing matched (the file is left untouched).
    pub inserted: bool,
    /// Byte offset at which the code was inserted (set only when inserted).
    /// UTF-8 byte offset into the file contents before the edit.
    #[napi(js_name = "byteOffset")]
    pub byte_offset: Option<u32>,
}
// ─── language detection (self-contained, no phf dependency) ──────────────────
/// Resolve a user-supplied language name string to a `SupportLang`.
/// Covers the same aliases as `forge_ast::ast::LANG_ALIASES` but implemented
/// as a simple match to avoid a direct `phf` dependency in this crate.
fn resolve_lang_from_str(value: &str) -> Result<SupportLang> {
let l = value.to_ascii_lowercase();
let lang = match l.as_str() {
"bash" | "sh" => SupportLang::Bash,
"c" => SupportLang::C,
"cpp" | "c++" | "cc" | "cxx" => SupportLang::Cpp,
"csharp" | "c#" | "cs" => SupportLang::CSharp,
"css" => SupportLang::Css,
"diff" | "patch" => SupportLang::Diff,
"elixir" | "ex" => SupportLang::Elixir,
"go" | "golang" => SupportLang::Go,
"haskell" | "hs" => SupportLang::Haskell,
"hcl" | "tf" | "tfvars" | "terraform" => SupportLang::Hcl,
"html" | "htm" => SupportLang::Html,
"java" => SupportLang::Java,
"javascript" | "js" | "jsx" | "mjs" | "cjs" => SupportLang::JavaScript,
"json" => SupportLang::Json,
"julia" | "jl" => SupportLang::Julia,
"kotlin" | "kt" => SupportLang::Kotlin,
"lua" => SupportLang::Lua,
"make" | "makefile" => SupportLang::Make,
"markdown" | "md" | "mdx" => SupportLang::Markdown,
"nix" => SupportLang::Nix,
"objc" | "objective-c" => SupportLang::ObjC,
"odin" => SupportLang::Odin,
"php" => SupportLang::Php,
"python" | "py" => SupportLang::Python,
"regex" => SupportLang::Regex,
"ruby" | "rb" => SupportLang::Ruby,
"rust" | "rs" => SupportLang::Rust,
"scala" => SupportLang::Scala,
"solidity" | "sol" => SupportLang::Solidity,
"starlark" | "star" => SupportLang::Starlark,
"swift" => SupportLang::Swift,
"toml" => SupportLang::Toml,
"tsx" => SupportLang::Tsx,
"typescript" | "ts" | "mts" | "cts" => SupportLang::TypeScript,
"verilog" | "systemverilog" | "sv" => SupportLang::Verilog,
"xml" | "xsl" | "svg" => SupportLang::Xml,
"yaml" | "yml" => SupportLang::Yaml,
"zig" => SupportLang::Zig,
_ => {
return Err(Error::from_reason(format!(
"Unsupported language '{value}'"
)))
}
};
Ok(lang)
}
/// Pick the language for `file_path`.
///
/// A non-blank `lang_opt` (surrounding whitespace ignored) wins and is resolved
/// through the alias table. Otherwise the language is inferred from the path
/// via the `Language::from_path` trait impl (which uses the file extension).
///
/// # Errors
/// Fails for an unknown explicit name, or when the path gives no usable hint.
fn resolve_lang(lang_opt: Option<&str>, file_path: &Path) -> Result<SupportLang> {
    match lang_opt.map(str::trim) {
        Some(explicit) if !explicit.is_empty() => resolve_lang_from_str(explicit),
        _ => <SupportLang as Language>::from_path(file_path).ok_or_else(|| {
            Error::from_reason(format!(
                "Cannot infer language from '{}'. Specify `lang` explicitly.",
                file_path.display()
            ))
        }),
    }
}
// ─── language family check ────────────────────────────────────────────────────
/// Returns `true` for the TypeScript/JavaScript/TSX family.
fn is_ts_js(lang: SupportLang) -> bool {
matches!(
lang,
SupportLang::TypeScript | SupportLang::JavaScript | SupportLang::Tsx
)
}
// ─── pattern building ─────────────────────────────────────────────────────────
/// Build ast-grep patterns to try for a given symbol name in a TS/JS/TSX file.
///
/// Plain name → function declaration + arrow patterns.
/// Dotted name like `"Class.method"` → class method pattern.
fn ts_patterns_for_symbol(symbol: &str) -> Result<Vec<String>> {
if symbol.contains('.') {
let parts: Vec<&str> = symbol.splitn(2, '.').collect();
let class_name = parts[0].trim();
let method_name = parts[1].trim();
if class_name.is_empty() || method_name.is_empty() {
return Err(Error::from_reason(format!(
"Invalid symbol name '{symbol}': expected 'ClassName.methodName'"
)));
}
Ok(vec![
// Method inside a named class
format!("class {class_name} {{ $$$ {method_name}($$$ARGS) {{ $$$BODY }} $$$ }}"),
])
} else {
Ok(vec![
// function declaration
format!("function {symbol}($$$ARGS) {{ $$$BODY }}"),
// arrow with parens
format!("const {symbol} = ($$$ARGS) => {{ $$$BODY }}"),
// arrow without parens (single param)
format!("const {symbol} = $ARG => {{ $$$BODY }}"),
])
}
}
// ─── matching helpers ─────────────────────────────────────────────────────────
struct SymbolMatch {
byte_start: usize,
byte_end: usize,
start_line: usize, // 0-based (ast-grep convention)
}
/// Run all patterns against `source` and collect distinct top-level matches,
/// deduped by start byte. Returns an error if more than one distinct
/// declaration was found (ambiguity).
fn find_symbol_matches(
source: &str,
patterns: &[String],
lang: SupportLang,
) -> Result<Vec<SymbolMatch>> {
let mut compiled: Vec<Pattern> = Vec::new();
for pat_str in patterns {
match Pattern::try_new(pat_str, lang) {
Ok(p) => compiled.push(p),
Err(_) => {} // skip patterns that don't compile for this lang variant
}
}
if compiled.is_empty() {
return Err(Error::from_reason(
"No patterns compiled successfully for this symbol/language combination".to_string(),
));
}
let ast = lang.ast_grep(source);
// BTreeMap keyed on start byte → deduplicates when multiple patterns hit
// the same node.
let mut by_start: std::collections::BTreeMap<usize, SymbolMatch> =
std::collections::BTreeMap::new();
for pattern in compiled {
for m in ast.root().find_all(pattern) {
let range = m.range();
by_start.entry(range.start).or_insert(SymbolMatch {
byte_start: range.start,
byte_end: range.end,
start_line: m.start_pos().line(),
});
}
}
Ok(by_start.into_values().collect())
}
// ─── atomic write ────────────────────────────────────────────────────────────
/// Atomically replace the contents of `path` with `content`.
///
/// The bytes go to a uniquely named hidden temp file next to the destination.
/// That file is then `rename`d over the destination, so readers see either the
/// old or the new contents, never a partial write.
///
/// * If `path` is a symlink, the file it points to is rewritten and the link is
///   kept. A plain rename would replace the link with a regular file.
/// * The permission bits of the existing file are copied to the new file, so an
///   executable script stays executable.
/// * With `do_fsync`, the temp file is synced before the rename. The parent
///   directory is synced afterwards, best effort.
/// * On any failure before the rename completes, the temp file is removed.
///
/// # Errors
/// Returns the underlying I/O error, or `InvalidInput` if the path has no
/// parent directory or no UTF-8 file name.
fn atomic_write_bytes(path: &Path, content: &[u8], do_fsync: bool) -> std::io::Result<()> {
    let target = resolve_write_target(path)?;
    let parent = target.parent().ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("path has no parent: {}", target.display()),
        )
    })?;
    let file_name = target.file_name().and_then(|s| s.to_str()).ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("path has no filename: {}", target.display()),
        )
    })?;
    let (tmp_path, tmp_file) = create_temp_sibling(parent, file_name)?;
    if let Err(err) = write_and_commit(tmp_file, &tmp_path, &target, content, do_fsync) {
        // Best-effort cleanup; the original error is what the caller needs.
        let _ = fs::remove_file(&tmp_path);
        return Err(err);
    }
    if do_fsync {
        // Persist the directory entry change (best effort, e.g. directories
        // cannot be opened this way on every platform).
        if let Ok(dir_fd) = fs::File::open(parent) {
            let _ = dir_fd.sync_all();
        }
    }
    Ok(())
}

/// Follow a symlinked `path` to the file it points to; other paths are
/// returned unchanged.
fn resolve_write_target(path: &Path) -> std::io::Result<PathBuf> {
    match fs::symlink_metadata(path) {
        Ok(meta) if meta.file_type().is_symlink() => fs::canonicalize(path),
        _ => Ok(path.to_path_buf()),
    }
}

/// Create a brand-new temp file in `dir` for `file_name` and return its path
/// together with the open handle.
///
/// The name combines the process id with a per-process counter, so two writes
/// in flight in the same process cannot collide on one temp name.
/// `create_new` ensures an existing entry is never reused or followed, such as
/// a stale temp file left by a crashed process with the same pid.
fn create_temp_sibling(dir: &Path, file_name: &str) -> std::io::Result<(PathBuf, fs::File)> {
    use std::sync::atomic::{AtomicUsize, Ordering};
    static SEQ: AtomicUsize = AtomicUsize::new(0);
    const MAX_ATTEMPTS: u32 = 16;

    let mut attempt = 0;
    loop {
        let seq = SEQ.fetch_add(1, Ordering::Relaxed);
        let candidate = dir.join(format!(".{file_name}.symbol.{}.{seq}", std::process::id()));
        match fs::OpenOptions::new()
            .write(true)
            .create_new(true)
            .open(&candidate)
        {
            Ok(file) => return Ok((candidate, file)),
            Err(err)
                if err.kind() == std::io::ErrorKind::AlreadyExists
                    && attempt + 1 < MAX_ATTEMPTS =>
            {
                attempt += 1;
            }
            Err(err) => return Err(err),
        }
    }
}

/// Fill `file` with `content`, copy the destination's permissions onto it,
/// optionally fsync it, then rename `tmp_path` over `target`.
fn write_and_commit(
    mut file: fs::File,
    tmp_path: &Path,
    target: &Path,
    content: &[u8],
    do_fsync: bool,
) -> std::io::Result<()> {
    use std::io::Write;
    file.write_all(content)?;
    // A fresh file gets default (umask-derived) permissions; keep the mode of
    // the file being replaced instead.
    if let Ok(meta) = fs::metadata(target) {
        file.set_permissions(meta.permissions())?;
    }
    if do_fsync {
        file.sync_all()?;
    }
    // Close the handle before renaming.
    drop(file);
    fs::rename(tmp_path, target)
}
// ─── public napi functions ───────────────────────────────────────────────────
/// Overwrite the whole declaration of `symbol_name` in `file_path` with
/// `new_body`.
///
/// `symbol_name` is a bare identifier (e.g. `"executeCommand"`) or a
/// `Class.method` path (e.g. `"MyClass.myMethod"`).
///
/// **v1 scope**: TypeScript / JavaScript / TSX only. Any other language is
/// rejected with a pointer to `astEdit` and a custom pattern.
///
/// **v1 replacement**: the complete matched declaration node is replaced, not
/// only its body, so `new_body` must be the full declaration text.
///
/// # Errors
/// Fails when the language cannot be resolved or is unsupported, when the file
/// cannot be read or written, or when several distinct declarations match.
/// No match is not an error: the result has `matched: false` and the file is
/// not written.
#[napi(js_name = "replaceSymbol")]
pub fn replace_symbol(
    file_path: String,
    symbol_name: String,
    new_body: String,
    options: Option<ReplaceSymbolOptions>,
) -> Result<ReplaceSymbolResult> {
    let (lang_hint, fsync_flag) = match options {
        Some(ReplaceSymbolOptions { lang, fsync }) => (lang, fsync),
        None => (None, None),
    };
    let do_fsync = fsync_flag.unwrap_or(true);
    let path = PathBuf::from(&file_path);

    let lang = resolve_lang(lang_hint.as_deref(), &path)?;
    if !is_ts_js(lang) {
        return Err(Error::from_reason(format!(
            "Language '{}' is not yet supported for symbol resolution. \
             Use astEdit with a custom pattern instead.",
            lang.canonical_name()
        )));
    }

    let source = fs::read_to_string(&path)
        .map_err(|e| Error::from_reason(format!("read {file_path}: {e}")))?;
    let patterns = ts_patterns_for_symbol(&symbol_name)?;
    let hits = find_symbol_matches(&source, &patterns, lang)?;

    let hit = match hits.as_slice() {
        [] => {
            return Ok(ReplaceSymbolResult {
                matched: false,
                byte_start: None,
                byte_end: None,
                start_line: None,
            })
        }
        [only] => only,
        several => {
            let n = several.len();
            return Err(Error::from_reason(format!(
                "Ambiguous symbol '{symbol_name}': found {n} matching declarations in '{file_path}'. \
                 Qualify the name (e.g. 'ClassName.methodName') or use astEdit with a narrower pattern."
            )));
        }
    };

    // Splice: untouched prefix + replacement + untouched suffix.
    let bytes = source.as_bytes();
    let spliced = [
        &bytes[..hit.byte_start],
        new_body.as_bytes(),
        &bytes[hit.byte_end..],
    ]
    .concat();
    atomic_write_bytes(&path, &spliced, do_fsync)
        .map_err(|e| Error::from_reason(format!("write {file_path}: {e}")))?;

    Ok(ReplaceSymbolResult {
        matched: true,
        byte_start: Some(hit.byte_start as u32),
        byte_end: Some(hit.byte_end as u32),
        // ast-grep lines are 0-based; the JS API reports 1-based lines.
        start_line: Some((hit.start_line + 1) as u32),
    })
}
/// Splice `code` into `file_path` directly before or after the declaration of
/// `symbol_name`.
///
/// `code` is inserted verbatim; no newline or indentation is added on either
/// side.
///
/// **v1 scope**: TypeScript / JavaScript / TSX only.
///
/// **v1 positions**: `BeforeDecl` and `AfterDecl` work. `AtBodyStart` and
/// `AtBodyEnd` are rejected with an error before any option parsing or file
/// access.
///
/// NOTE(review): the plain-name patterns begin at `function` / `const`. For
/// `export function foo` the match therefore presumably starts after
/// `export `, and `BeforeDecl` would insert between the two keywords. Confirm
/// whether the anchor should widen to the export statement.
///
/// # Errors
/// Fails for unimplemented positions, unresolvable or unsupported languages,
/// read/write failures, and when several distinct declarations match. No match
/// yields `inserted: false` without writing the file.
#[napi(js_name = "insertAroundSymbol")]
pub fn insert_around_symbol(
    file_path: String,
    symbol_name: String,
    position: InsertPosition,
    code: String,
    options: Option<InsertAroundSymbolOptions>,
) -> Result<InsertAroundSymbolResult> {
    // Decide the anchor side up front so unsupported positions fail fast.
    let anchor_at_start = match position {
        InsertPosition::BeforeDecl => true,
        InsertPosition::AfterDecl => false,
        InsertPosition::AtBodyStart | InsertPosition::AtBodyEnd => {
            return Err(Error::from_reason(
                "AtBodyStart / AtBodyEnd are not yet implemented in v1. \
                 Use BeforeDecl or AfterDecl, or use astEdit with a custom pattern."
                    .to_string(),
            ));
        }
    };

    let (lang_hint, fsync_flag) = match options {
        Some(InsertAroundSymbolOptions { lang, fsync }) => (lang, fsync),
        None => (None, None),
    };
    let do_fsync = fsync_flag.unwrap_or(true);
    let path = PathBuf::from(&file_path);

    let lang = resolve_lang(lang_hint.as_deref(), &path)?;
    if !is_ts_js(lang) {
        return Err(Error::from_reason(format!(
            "Language '{}' is not yet supported for symbol resolution. \
             Use astEdit with a custom pattern instead.",
            lang.canonical_name()
        )));
    }

    let source = fs::read_to_string(&path)
        .map_err(|e| Error::from_reason(format!("read {file_path}: {e}")))?;
    let patterns = ts_patterns_for_symbol(&symbol_name)?;
    let hits = find_symbol_matches(&source, &patterns, lang)?;

    let hit = match hits.as_slice() {
        [] => {
            return Ok(InsertAroundSymbolResult {
                inserted: false,
                byte_offset: None,
            })
        }
        [only] => only,
        several => {
            let n = several.len();
            return Err(Error::from_reason(format!(
                "Ambiguous symbol '{symbol_name}': found {n} matching declarations in '{file_path}'. \
                 Qualify the name (e.g. 'ClassName.methodName') or use astEdit with a narrower pattern."
            )));
        }
    };

    let offset = if anchor_at_start {
        hit.byte_start
    } else {
        hit.byte_end
    };
    let bytes = source.as_bytes();
    let spliced = [&bytes[..offset], code.as_bytes(), &bytes[offset..]].concat();
    atomic_write_bytes(&path, &spliced, do_fsync)
        .map_err(|e| Error::from_reason(format!("write {file_path}: {e}")))?;

    Ok(InsertAroundSymbolResult {
        inserted: true,
        byte_offset: Some(offset as u32),
    })
}