451 lines
16 KiB
Rust
451 lines
16 KiB
Rust
//! Symbol-level structural replace and insert.
|
|
//!
|
|
//! Exposes [`replace_symbol`] and [`insert_around_symbol`] to JavaScript via
|
|
//! napi-rs. Both functions use the same ast-grep infrastructure as
|
|
//! `forge_ast::ast` (tree-sitter pattern matching) but add a higher-level
|
|
//! "find me the declaration named X" abstraction on top.
|
|
//!
|
|
//! ## Language support (v1)
|
|
//!
|
|
//! Only **TypeScript / JavaScript / TSX** are fully supported. For all other
|
|
//! languages the functions return an `Err` asking the caller to fall back to
|
|
//! `astEdit` with a custom pattern.
|
|
//!
|
|
//! ## Replacement scope (v1)
|
|
//!
|
|
//! For simplicity `replaceSymbol` replaces the **entire matched declaration**
|
|
//! (function / arrow / method node), not just the body. `new_body` is
|
|
//! therefore expected to be the full declaration text, e.g.:
|
|
//!
|
|
//! ```text
|
|
//! function foo(x: number): number { return x + 1; }
|
|
//! ```
|
|
//!
|
|
//! `insertAroundSymbol` supports only `BeforeDecl` and `AfterDecl` in v1.
|
|
//! `AtBodyStart` / `AtBodyEnd` return `Err("not yet implemented")`.
|
|
|
|
use ast_grep_core::{matcher::Pattern, tree_sitter::LanguageExt, Language};
|
|
use forge_ast::language::SupportLang;
|
|
use napi::{Error, Result};
|
|
use napi_derive::napi;
|
|
use std::{
|
|
fs,
|
|
path::{Path, PathBuf},
|
|
};
|
|
|
|
// ─── napi types ──────────────────────────────────────────────────────────────
|
|
|
|
/// Options accepted by `replaceSymbol`.
#[napi(object)]
pub struct ReplaceSymbolOptions {
    /// Force a specific language ("typescript", "rust", "python", …).
    /// If absent (or empty / whitespace-only), inferred from the file extension.
    /// In v1 only the TypeScript / JavaScript / TSX family is accepted; any other
    /// resolved language makes `replaceSymbol` return an error.
    pub lang: Option<String>,
    /// fsync the write. Defaults to true.
    pub fsync: Option<bool>,
}
|
|
|
|
/// Outcome of `replaceSymbol`.
///
/// Offsets and line numbers describe the file as it was *before* the
/// replacement.
#[napi(object)]
pub struct ReplaceSymbolResult {
    /// `true` when exactly one declaration matched and the file was rewritten;
    /// `false` when nothing matched (the file is left untouched).
    pub matched: bool,
    /// Byte offset of the start of the replaced range (set only when matched).
    #[napi(js_name = "byteStart")]
    pub byte_start: Option<u32>,
    /// Byte offset of the end of the replaced range (set only when matched).
    #[napi(js_name = "byteEnd")]
    pub byte_end: Option<u32>,
    /// 1-based line number of the replacement start (set only when matched).
    #[napi(js_name = "startLine")]
    pub start_line: Option<u32>,
}
|
|
|
|
/// Where `insertAroundSymbol` places the new code relative to the matched
/// declaration.
#[napi(string_enum)]
pub enum InsertPosition {
    /// Immediately before the first byte of the matched declaration.
    BeforeDecl,
    /// Immediately after the last byte of the matched declaration.
    AfterDecl,
    /// Intended: start of the declaration's body. Not implemented in v1
    /// (`insertAroundSymbol` returns an error).
    AtBodyStart,
    /// Intended: end of the declaration's body. Not implemented in v1
    /// (`insertAroundSymbol` returns an error).
    AtBodyEnd,
}
|
|
|
|
/// Options accepted by `insertAroundSymbol`.
#[napi(object)]
pub struct InsertAroundSymbolOptions {
    /// Force a specific language ("typescript", "tsx", …). If absent (or empty /
    /// whitespace-only), inferred from the file extension. In v1 only the
    /// TypeScript / JavaScript / TSX family is accepted.
    pub lang: Option<String>,
    /// fsync the write. Defaults to true.
    pub fsync: Option<bool>,
}
|
|
|
|
/// Outcome of `insertAroundSymbol`.
#[napi(object)]
pub struct InsertAroundSymbolResult {
    /// `true` when exactly one declaration matched and the file was rewritten;
    /// `false` when nothing matched (the file is left untouched).
    pub inserted: bool,
    /// Byte offset at which the code was inserted (set only when inserted),
    /// measured in the file as it was *before* the insertion.
    #[napi(js_name = "byteOffset")]
    pub byte_offset: Option<u32>,
}
|
|
|
|
// ─── language detection (self-contained, no phf dependency) ──────────────────
|
|
|
|
/// Resolve a user-supplied language name string to a `SupportLang`.
|
|
/// Covers the same aliases as `forge_ast::ast::LANG_ALIASES` but implemented
|
|
/// as a simple match to avoid a direct `phf` dependency in this crate.
|
|
fn resolve_lang_from_str(value: &str) -> Result<SupportLang> {
|
|
let l = value.to_ascii_lowercase();
|
|
let lang = match l.as_str() {
|
|
"bash" | "sh" => SupportLang::Bash,
|
|
"c" => SupportLang::C,
|
|
"cpp" | "c++" | "cc" | "cxx" => SupportLang::Cpp,
|
|
"csharp" | "c#" | "cs" => SupportLang::CSharp,
|
|
"css" => SupportLang::Css,
|
|
"diff" | "patch" => SupportLang::Diff,
|
|
"elixir" | "ex" => SupportLang::Elixir,
|
|
"go" | "golang" => SupportLang::Go,
|
|
"haskell" | "hs" => SupportLang::Haskell,
|
|
"hcl" | "tf" | "tfvars" | "terraform" => SupportLang::Hcl,
|
|
"html" | "htm" => SupportLang::Html,
|
|
"java" => SupportLang::Java,
|
|
"javascript" | "js" | "jsx" | "mjs" | "cjs" => SupportLang::JavaScript,
|
|
"json" => SupportLang::Json,
|
|
"julia" | "jl" => SupportLang::Julia,
|
|
"kotlin" | "kt" => SupportLang::Kotlin,
|
|
"lua" => SupportLang::Lua,
|
|
"make" | "makefile" => SupportLang::Make,
|
|
"markdown" | "md" | "mdx" => SupportLang::Markdown,
|
|
"nix" => SupportLang::Nix,
|
|
"objc" | "objective-c" => SupportLang::ObjC,
|
|
"odin" => SupportLang::Odin,
|
|
"php" => SupportLang::Php,
|
|
"python" | "py" => SupportLang::Python,
|
|
"regex" => SupportLang::Regex,
|
|
"ruby" | "rb" => SupportLang::Ruby,
|
|
"rust" | "rs" => SupportLang::Rust,
|
|
"scala" => SupportLang::Scala,
|
|
"solidity" | "sol" => SupportLang::Solidity,
|
|
"starlark" | "star" => SupportLang::Starlark,
|
|
"swift" => SupportLang::Swift,
|
|
"toml" => SupportLang::Toml,
|
|
"tsx" => SupportLang::Tsx,
|
|
"typescript" | "ts" | "mts" | "cts" => SupportLang::TypeScript,
|
|
"verilog" | "systemverilog" | "sv" => SupportLang::Verilog,
|
|
"xml" | "xsl" | "svg" => SupportLang::Xml,
|
|
"yaml" | "yml" => SupportLang::Yaml,
|
|
"zig" => SupportLang::Zig,
|
|
_ => {
|
|
return Err(Error::from_reason(format!(
|
|
"Unsupported language '{value}'"
|
|
)))
|
|
}
|
|
};
|
|
Ok(lang)
|
|
}
|
|
|
|
fn resolve_lang(lang_opt: Option<&str>, file_path: &Path) -> Result<SupportLang> {
|
|
if let Some(lang) = lang_opt.map(str::trim).filter(|l| !l.is_empty()) {
|
|
return resolve_lang_from_str(lang);
|
|
}
|
|
// Use the SupportLang trait impl which calls from_extension internally.
|
|
<SupportLang as Language>::from_path(file_path).ok_or_else(|| {
|
|
Error::from_reason(format!(
|
|
"Cannot infer language from '{}'. Specify `lang` explicitly.",
|
|
file_path.display()
|
|
))
|
|
})
|
|
}
|
|
|
|
// ─── language family check ────────────────────────────────────────────────────
|
|
|
|
/// Returns `true` for the TypeScript/JavaScript/TSX family.
|
|
fn is_ts_js(lang: SupportLang) -> bool {
|
|
matches!(
|
|
lang,
|
|
SupportLang::TypeScript | SupportLang::JavaScript | SupportLang::Tsx
|
|
)
|
|
}
|
|
|
|
// ─── pattern building ─────────────────────────────────────────────────────────
|
|
|
|
/// Build ast-grep patterns to try for a given symbol name in a TS/JS/TSX file.
|
|
///
|
|
/// Plain name → function declaration + arrow patterns.
|
|
/// Dotted name like `"Class.method"` → class method pattern.
|
|
fn ts_patterns_for_symbol(symbol: &str) -> Result<Vec<String>> {
|
|
if symbol.contains('.') {
|
|
let parts: Vec<&str> = symbol.splitn(2, '.').collect();
|
|
let class_name = parts[0].trim();
|
|
let method_name = parts[1].trim();
|
|
if class_name.is_empty() || method_name.is_empty() {
|
|
return Err(Error::from_reason(format!(
|
|
"Invalid symbol name '{symbol}': expected 'ClassName.methodName'"
|
|
)));
|
|
}
|
|
Ok(vec![
|
|
// Method inside a named class
|
|
format!("class {class_name} {{ $$$ {method_name}($$$ARGS) {{ $$$BODY }} $$$ }}"),
|
|
])
|
|
} else {
|
|
Ok(vec![
|
|
// function declaration
|
|
format!("function {symbol}($$$ARGS) {{ $$$BODY }}"),
|
|
// arrow with parens
|
|
format!("const {symbol} = ($$$ARGS) => {{ $$$BODY }}"),
|
|
// arrow without parens (single param)
|
|
format!("const {symbol} = $ARG => {{ $$$BODY }}"),
|
|
])
|
|
}
|
|
}
|
|
|
|
// ─── matching helpers ─────────────────────────────────────────────────────────
|
|
|
|
struct SymbolMatch {
|
|
byte_start: usize,
|
|
byte_end: usize,
|
|
start_line: usize, // 0-based (ast-grep convention)
|
|
}
|
|
|
|
/// Run all patterns against `source` and collect distinct top-level matches,
|
|
/// deduped by start byte. Returns an error if more than one distinct
|
|
/// declaration was found (ambiguity).
|
|
fn find_symbol_matches(
|
|
source: &str,
|
|
patterns: &[String],
|
|
lang: SupportLang,
|
|
) -> Result<Vec<SymbolMatch>> {
|
|
let mut compiled: Vec<Pattern> = Vec::new();
|
|
for pat_str in patterns {
|
|
match Pattern::try_new(pat_str, lang) {
|
|
Ok(p) => compiled.push(p),
|
|
Err(_) => {} // skip patterns that don't compile for this lang variant
|
|
}
|
|
}
|
|
if compiled.is_empty() {
|
|
return Err(Error::from_reason(
|
|
"No patterns compiled successfully for this symbol/language combination".to_string(),
|
|
));
|
|
}
|
|
|
|
let ast = lang.ast_grep(source);
|
|
// BTreeMap keyed on start byte → deduplicates when multiple patterns hit
|
|
// the same node.
|
|
let mut by_start: std::collections::BTreeMap<usize, SymbolMatch> =
|
|
std::collections::BTreeMap::new();
|
|
|
|
for pattern in compiled {
|
|
for m in ast.root().find_all(pattern) {
|
|
let range = m.range();
|
|
by_start.entry(range.start).or_insert(SymbolMatch {
|
|
byte_start: range.start,
|
|
byte_end: range.end,
|
|
start_line: m.start_pos().line(),
|
|
});
|
|
}
|
|
}
|
|
|
|
Ok(by_start.into_values().collect())
|
|
}
|
|
|
|
// ─── atomic write ────────────────────────────────────────────────────────────
|
|
|
|
/// Atomically replace the contents of `path` with `content`.
///
/// The bytes go to a hidden sibling temp file, which is then renamed over `path`.
/// The temp file sits in the same directory so the rename stays on one
/// filesystem. Readers see either the old or the new contents, never a partial
/// write.
///
/// * The existing file's permissions are copied onto the temp file, so a rewrite
///   does not reset e.g. an executable bit. This is best effort and is skipped
///   when `path` does not exist yet.
/// * The temp name combines the process id with a per-process counter, so
///   concurrent calls from different threads of this process cannot collide.
/// * The temp file is removed again if writing, syncing or renaming fails.
/// * With `do_fsync`, the temp file is synced before the rename and the parent
///   directory (best effort) after it.
///
/// NOTE(review): if `path` is a symlink, the rename replaces the link itself with
/// a regular file.
///
/// # Errors
/// Returns `InvalidInput` when `path` has no parent or no file name. Otherwise
/// returns the underlying I/O error from creating, writing, syncing or renaming.
fn atomic_write_bytes(path: &Path, content: &[u8], do_fsync: bool) -> std::io::Result<()> {
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Create `tmp`, fill it with `content`, mirror `target`'s permissions onto
    /// it (best effort) and, when `do_fsync`, flush it to stable storage.
    fn write_tmp(
        tmp: &Path,
        target: &Path,
        content: &[u8],
        do_fsync: bool,
    ) -> std::io::Result<()> {
        use std::io::Write;

        let mut f = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .open(tmp)?;
        f.write_all(content)?;
        if let Ok(meta) = std::fs::metadata(target) {
            // Best effort: some filesystems reject chmod, and the edit itself
            // is more important than the mode bits.
            let _ = f.set_permissions(meta.permissions());
        }
        if do_fsync {
            f.sync_all()?;
        }
        Ok(())
    }

    // Per-process sequence number: the pid alone is shared by every thread
    // (e.g. JS worker threads) of this process.
    static TMP_SEQ: AtomicUsize = AtomicUsize::new(0);

    let parent = path.parent().ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("path has no parent: {}", path.display()),
        )
    })?;
    let file_name = path.file_name().ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("path has no filename: {}", path.display()),
        )
    })?;
    // A bare relative name ("foo.ts") has an empty parent; use "." so the
    // directory fsync below targets the real directory.
    let dir = if parent.as_os_str().is_empty() {
        Path::new(".")
    } else {
        parent
    };

    // Built as an OsString so non-UTF-8 file names are supported.
    let mut tmp_name = std::ffi::OsString::from(".");
    tmp_name.push(file_name);
    tmp_name.push(format!(
        ".symbol.{}.{}",
        std::process::id(),
        TMP_SEQ.fetch_add(1, Ordering::Relaxed)
    ));
    let tmp_path = dir.join(tmp_name);

    if let Err(e) = write_tmp(&tmp_path, path, content, do_fsync) {
        let _ = fs::remove_file(&tmp_path);
        return Err(e);
    }
    if let Err(e) = fs::rename(&tmp_path, path) {
        let _ = fs::remove_file(&tmp_path);
        return Err(e);
    }

    if do_fsync {
        // Persist the rename itself. Best effort: opening a directory as a file
        // is not supported on every platform.
        if let Ok(dir_fd) = std::fs::File::open(dir) {
            let _ = dir_fd.sync_all();
        }
    }

    Ok(())
}
|
|
|
|
// ─── public napi functions ───────────────────────────────────────────────────
|
|
|
|
/// Replace the entire declaration of the symbol identified by `symbol_name`
|
|
/// with `new_body`.
|
|
///
|
|
/// `symbol_name` is either a plain identifier (e.g. `"executeCommand"`) or a
|
|
/// dotted path (e.g. `"MyClass.myMethod"`).
|
|
///
|
|
/// **v1 scope**: only TypeScript / JavaScript / TSX are supported. For other
|
|
/// languages use `astEdit` with a custom pattern.
|
|
///
|
|
/// **v1 replacement**: the *entire* matched declaration node is replaced, not
|
|
/// just its body. `new_body` should be the complete declaration text.
|
|
///
|
|
/// Returns `matched: false` when no declaration matches. Returns an error
|
|
/// when multiple distinct declarations match (ambiguity).
|
|
#[napi(js_name = "replaceSymbol")]
|
|
pub fn replace_symbol(
|
|
file_path: String,
|
|
symbol_name: String,
|
|
new_body: String,
|
|
options: Option<ReplaceSymbolOptions>,
|
|
) -> Result<ReplaceSymbolResult> {
|
|
let opts = options.unwrap_or(ReplaceSymbolOptions {
|
|
lang: None,
|
|
fsync: None,
|
|
});
|
|
let do_fsync = opts.fsync.unwrap_or(true);
|
|
|
|
let path = PathBuf::from(&file_path);
|
|
let lang = resolve_lang(opts.lang.as_deref(), &path)?;
|
|
|
|
if !is_ts_js(lang) {
|
|
return Err(Error::from_reason(format!(
|
|
"Language '{}' is not yet supported for symbol resolution. \
|
|
Use astEdit with a custom pattern instead.",
|
|
lang.canonical_name()
|
|
)));
|
|
}
|
|
|
|
let source = fs::read_to_string(&path)
|
|
.map_err(|e| Error::from_reason(format!("read {file_path}: {e}")))?;
|
|
|
|
let patterns = ts_patterns_for_symbol(&symbol_name)?;
|
|
let matches = find_symbol_matches(&source, &patterns, lang)?;
|
|
|
|
match matches.len() {
|
|
0 => Ok(ReplaceSymbolResult {
|
|
matched: false,
|
|
byte_start: None,
|
|
byte_end: None,
|
|
start_line: None,
|
|
}),
|
|
1 => {
|
|
let m = &matches[0];
|
|
let before = &source.as_bytes()[..m.byte_start];
|
|
let after = &source.as_bytes()[m.byte_end..];
|
|
let mut out = Vec::with_capacity(before.len() + new_body.len() + after.len());
|
|
out.extend_from_slice(before);
|
|
out.extend_from_slice(new_body.as_bytes());
|
|
out.extend_from_slice(after);
|
|
|
|
atomic_write_bytes(&path, &out, do_fsync)
|
|
.map_err(|e| Error::from_reason(format!("write {file_path}: {e}")))?;
|
|
|
|
Ok(ReplaceSymbolResult {
|
|
matched: true,
|
|
byte_start: Some(m.byte_start as u32),
|
|
byte_end: Some(m.byte_end as u32),
|
|
start_line: Some((m.start_line + 1) as u32),
|
|
})
|
|
}
|
|
n => Err(Error::from_reason(format!(
|
|
"Ambiguous symbol '{symbol_name}': found {n} matching declarations in '{file_path}'. \
|
|
Qualify the name (e.g. 'ClassName.methodName') or use astEdit with a narrower pattern."
|
|
))),
|
|
}
|
|
}
|
|
|
|
/// Insert `code` before or after the declaration of the symbol identified by
|
|
/// `symbol_name`.
|
|
///
|
|
/// **v1 scope**: only TypeScript / JavaScript / TSX are supported.
|
|
///
|
|
/// **v1 positions**: only `BeforeDecl` and `AfterDecl` are implemented.
|
|
/// `AtBodyStart` / `AtBodyEnd` return `Err("not yet implemented")`.
|
|
#[napi(js_name = "insertAroundSymbol")]
|
|
pub fn insert_around_symbol(
|
|
file_path: String,
|
|
symbol_name: String,
|
|
position: InsertPosition,
|
|
code: String,
|
|
options: Option<InsertAroundSymbolOptions>,
|
|
) -> Result<InsertAroundSymbolResult> {
|
|
match position {
|
|
InsertPosition::AtBodyStart | InsertPosition::AtBodyEnd => {
|
|
return Err(Error::from_reason(
|
|
"AtBodyStart / AtBodyEnd are not yet implemented in v1. \
|
|
Use BeforeDecl or AfterDecl, or use astEdit with a custom pattern."
|
|
.to_string(),
|
|
));
|
|
}
|
|
_ => {}
|
|
}
|
|
|
|
let opts = options.unwrap_or(InsertAroundSymbolOptions {
|
|
lang: None,
|
|
fsync: None,
|
|
});
|
|
let do_fsync = opts.fsync.unwrap_or(true);
|
|
|
|
let path = PathBuf::from(&file_path);
|
|
let lang = resolve_lang(opts.lang.as_deref(), &path)?;
|
|
|
|
if !is_ts_js(lang) {
|
|
return Err(Error::from_reason(format!(
|
|
"Language '{}' is not yet supported for symbol resolution. \
|
|
Use astEdit with a custom pattern instead.",
|
|
lang.canonical_name()
|
|
)));
|
|
}
|
|
|
|
let source = fs::read_to_string(&path)
|
|
.map_err(|e| Error::from_reason(format!("read {file_path}: {e}")))?;
|
|
|
|
let patterns = ts_patterns_for_symbol(&symbol_name)?;
|
|
let matches = find_symbol_matches(&source, &patterns, lang)?;
|
|
|
|
match matches.len() {
|
|
0 => Ok(InsertAroundSymbolResult {
|
|
inserted: false,
|
|
byte_offset: None,
|
|
}),
|
|
1 => {
|
|
let m = &matches[0];
|
|
let insert_at = match position {
|
|
InsertPosition::BeforeDecl => m.byte_start,
|
|
InsertPosition::AfterDecl => m.byte_end,
|
|
_ => unreachable!(),
|
|
};
|
|
|
|
let before = &source.as_bytes()[..insert_at];
|
|
let after = &source.as_bytes()[insert_at..];
|
|
let mut out = Vec::with_capacity(before.len() + code.len() + after.len());
|
|
out.extend_from_slice(before);
|
|
out.extend_from_slice(code.as_bytes());
|
|
out.extend_from_slice(after);
|
|
|
|
atomic_write_bytes(&path, &out, do_fsync)
|
|
.map_err(|e| Error::from_reason(format!("write {file_path}: {e}")))?;
|
|
|
|
Ok(InsertAroundSymbolResult {
|
|
inserted: true,
|
|
byte_offset: Some(insert_at as u32),
|
|
})
|
|
}
|
|
n => Err(Error::from_reason(format!(
|
|
"Ambiguous symbol '{symbol_name}': found {n} matching declarations in '{file_path}'. \
|
|
Qualify the name (e.g. 'ClassName.methodName') or use astEdit with a narrower pattern."
|
|
))),
|
|
}
|
|
}
|