diff --git a/Cargo.toml b/Cargo.toml index f04e04c6e..0b5072302 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -67,6 +67,7 @@ getset = "0.1.2" pyo3 = "0.19.0" pyo3-log = "0.8.1" glob = "0.3.1" +ast-grep-core = "0.7.2" [features] extension-module = ["pyo3/extension-module"] diff --git a/src/models/capture_group_patterns.rs b/src/models/capture_group_patterns.rs index 17309fb5a..28b16963d 100644 --- a/src/models/capture_group_patterns.rs +++ b/src/models/capture_group_patterns.rs @@ -14,18 +14,23 @@ Copyright (c) 2023 Uber Technologies, Inc. use crate::{ models::Validator, utilities::{ + ast_grep_utilities::get_all_matches_for_ast_grep_pattern, regex_utilities::get_all_matches_for_regex, tree_sitter_utilities::{get_all_matches_for_query, get_ts_query_parser, number_of_errors}, Instantiate, }, }; +use ast_grep_core::{language::TSLanguage, Pattern, StrDoc}; use pyo3::prelude::pyclass; use regex::Regex; use serde_derive::Deserialize; use std::collections::HashMap; use tree_sitter::{Node, Query}; -use super::{default_configs::REGEX_QUERY_PREFIX, matches::Match}; +use super::{ + default_configs::{AST_GREP_PREFIX, REGEX_QUERY_PREFIX}, + matches::Match, +}; #[pyclass] #[derive(Deserialize, Debug, Clone, Default, PartialEq, Hash, Eq)] @@ -54,6 +59,9 @@ impl Validator for CGPattern { .map(|_| Ok(())) .unwrap_or(Err(format!("Cannot parse the regex - {}", self.pattern()))); } + if self.pattern().starts_with(AST_GREP_PREFIX) { + return Ok(()); + } let mut parser = get_ts_query_parser(); parser .parse(self.pattern(), None) @@ -78,6 +86,7 @@ impl Instantiate for CGPattern { #[derive(Debug)] pub(crate) enum CompiledCGPattern { + P(Pattern>), Q(Query), R(Regex), } @@ -104,6 +113,7 @@ impl CompiledCGPattern { replace_node_idx: Option, ) -> Vec { match self { + CompiledCGPattern::P(_pattern) => panic!("ast-grep pattern is not supported"), CompiledCGPattern::Q(query) => get_all_matches_for_query( node, source_code, @@ -112,9 +122,7 @@ impl CompiledCGPattern { replace_node, replace_node_idx, ), - CompiledCGPattern::R(regex) => { - get_all_matches_for_regex(node, source_code, regex, recursive, replace_node) - } + CompiledCGPattern::R(regex) => get_all_matches_for_ast_grep_pattern(), } } } diff --git a/src/models/default_configs.rs b/src/models/default_configs.rs index 269136979..a5f5a3e04 100644 --- a/src/models/default_configs.rs +++ b/src/models/default_configs.rs @@ -32,6 +32,7 @@ pub const STRINGS: &str = "strings"; pub const TS_SCHEME: &str = "scm"; // We support scheme files that contain tree-sitter query pub const REGEX_QUERY_PREFIX: &str = "rgx "; +pub const AST_GREP_PREFIX: &str = "sg "; #[cfg(test)] //FIXME: Remove this hack by not passing PiranhaArguments to SourceCodeUnit diff --git a/src/models/matches.rs b/src/models/matches.rs index 193169dda..49ca22597 100644 --- a/src/models/matches.rs +++ b/src/models/matches.rs @@ -13,6 +13,7 @@ Copyright (c) 2023 Uber Technologies, Inc. use std::collections::HashMap; +use ast_grep_core::{language::TSLanguage, StrDoc}; use getset::{Getters, MutGetters}; use itertools::Itertools; use log::trace; @@ -67,6 +68,19 @@ impl Match { } } + pub(crate) fn from_ast_grep_captures( + captures: &ast_grep_core::Node<'_, StrDoc>, matches: HashMap, + source_code: &str, + ) -> Self { + Match { + matched_string: captures.text().to_string(), + range: Range::from_range(&captures.range(), source_code), + matches, + associated_comma: None, + associated_comments: Vec::new(), + } + } + pub(crate) fn new( matched_string: String, range: tree_sitter::Range, matches: HashMap, ) -> Self { @@ -281,6 +295,15 @@ impl Range { end_point: position_for_offset(source_code.as_bytes(), mtch.end()), } } + + pub(crate) fn from_range(range: &std::ops::Range, source_code: &str) -> Self { + Self { + start_byte: range.start, + end_byte: range.end, + start_point: position_for_offset(source_code.as_bytes(), range.start), + end_point: position_for_offset(source_code.as_bytes(), range.end), + } + } } // Finds the position (col and row number) for a given offset. diff --git a/src/utilities/ast_grep_utilities.rs b/src/utilities/ast_grep_utilities.rs new file mode 100644 index 000000000..226142fc1 --- /dev/null +++ b/src/utilities/ast_grep_utilities.rs @@ -0,0 +1,81 @@ +use std::{collections::HashMap, hash::Hash}; + +/* +Copyright (c) 2023 Uber Technologies, Inc. + +

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file +except in compliance with the License. You may obtain a copy of the License at +

http://www.apache.org/licenses/LICENSE-2.0 + +

Unless required by applicable law or agreed to in writing, software distributed under the +License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +express or implied. See the License for the specific language governing permissions and +limitations under the License. +*/ +use ast_grep_core::{language::TSLanguage, AstGrep, Matcher, Pattern, StrDoc}; +use tree_sitter::Node; + +use crate::models::matches::Match; + +/// Applies the query upon the given `node`, and gets all the matches +/// # Arguments +/// * `node` - the root node to apply the query upon +/// * `source_code` - the corresponding source code string for the node. +/// * `recursive` - if `true` it matches the query to `self` and `self`'s sub-ASTs, else it matches the `query` only to `self`. +/// * `replace_node` - node to replace +/// +/// # Returns +/// List containing all the matches against `node` +pub(crate) fn get_all_matches_for_ast_grep_pattern( + node: &Node, source_code: String, pattern: &Pattern>, recursive: bool, + replace_node: Option, language: tree_sitter::Language, +) -> Vec { + let x = AstGrep::new(&source_code, TSLanguage::from(language)); + + let all_captures = x.root().find_all(pattern); + let mut all_matches = vec![]; + for captures in all_captures { + let range_matches_node = + node.start_byte() == captures.range().start && node.end_byte() == captures.range().end; + let range_matches_inside_node = + node.start_byte() <= captures.range().start && node.end_byte() >= captures.range().end; + if (recursive && range_matches_inside_node) || range_matches_node { + let replace_node_match = if let Some(ref rn) = replace_node { + captures + .get_env() + .get_match(rn) + .unwrap_or_else(|| panic!("The tag {rn} provided in the replace node is not present")) + } else { + captures.get_node() + }; + let matches = extract_captures(&captures); + all_matches.push(Match::from_ast_grep_captures( + &replace_node_match, + matches, + &source_code, + )); + } + } + vec![] +} + +fn extract_captures( + captures: &ast_grep_core::NodeMatch<'_, StrDoc>, +) -> HashMap { + let mut map = HashMap::new(); + for v in captures.get_env().get_matched_variables() { + let name = match v { + ast_grep_core::meta_var::MetaVariable::Named(name, _) => Some(name), + ast_grep_core::meta_var::MetaVariable::Anonymous(_) => None, + ast_grep_core::meta_var::MetaVariable::Ellipsis => None, + ast_grep_core::meta_var::MetaVariable::NamedEllipsis(name) => Some(name), + }; + if let Some(n) = name { + map.insert( + n.to_string(), + captures.get_env().get_match(&n).unwrap().text().to_string(), + ); + } + } + return map; +} diff --git a/src/utilities/mod.rs b/src/utilities/mod.rs index 7ca6df029..d6bde309b 100644 --- a/src/utilities/mod.rs +++ b/src/utilities/mod.rs @@ -11,6 +11,7 @@ Copyright (c) 2023 Uber Technologies, Inc. limitations under the License. */ +pub(crate) mod ast_grep_utilities; pub(crate) mod regex_utilities; pub(crate) mod tree_sitter_utilities; use std::collections::HashMap; diff --git a/utilities/ast_grep_utilities.rs b/utilities/ast_grep_utilities.rs new file mode 100644 index 000000000..c46e40f0e --- /dev/null +++ b/utilities/ast_grep_utilities.rs @@ -0,0 +1,13 @@ +/* + Copyright (c) 2023 Uber Technologies, Inc. + +

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file + except in compliance with the License. You may obtain a copy of the License at +

http://www.apache.org/licenses/LICENSE-2.0 + +

Unless required by applicable law or agreed to in writing, software distributed under the + License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + express or implied. See the License for the specific language governing permissions and + limitations under the License. + */ +