diff --git a/src/models/capture_group_patterns.rs b/src/models/capture_group_patterns.rs index 1929d6bee..17309fb5a 100644 --- a/src/models/capture_group_patterns.rs +++ b/src/models/capture_group_patterns.rs @@ -14,6 +14,7 @@ Copyright (c) 2023 Uber Technologies, Inc. use crate::{ models::Validator, utilities::{ + regex_utilities::get_all_matches_for_regex, tree_sitter_utilities::{get_all_matches_for_query, get_ts_query_parser, number_of_errors}, Instantiate, }, @@ -24,7 +25,7 @@ use serde_derive::Deserialize; use std::collections::HashMap; use tree_sitter::{Node, Query}; -use super::matches::Match; +use super::{default_configs::REGEX_QUERY_PREFIX, matches::Match}; #[pyclass] #[derive(Deserialize, Debug, Clone, Default, PartialEq, Hash, Eq)] @@ -38,12 +39,20 @@ impl CGPattern { pub(crate) fn pattern(&self) -> String { self.0.to_string() } + + pub(crate) fn extract_regex(&self) -> Result { + let mut _val = &self.pattern()[REGEX_QUERY_PREFIX.len()..]; + Regex::new(_val) + } } impl Validator for CGPattern { fn validate(&self) -> Result<(), String> { if self.pattern().starts_with("rgx ") { - panic!("Regex not supported") + return self + .extract_regex() + .map(|_| Ok(())) + .unwrap_or(Err(format!("Cannot parse the regex - {}", self.pattern()))); } let mut parser = get_ts_query_parser(); parser @@ -70,7 +79,7 @@ impl Instantiate for CGPattern { #[derive(Debug)] pub(crate) enum CompiledCGPattern { Q(Query), - R(Regex), // Regex is not yet supported + R(Regex), } impl CompiledCGPattern { @@ -103,7 +112,9 @@ impl CompiledCGPattern { replace_node, replace_node_idx, ), - CompiledCGPattern::R(_) => panic!("Regex is not yet supported!!!"), + CompiledCGPattern::R(regex) => { + get_all_matches_for_regex(node, source_code, regex, recursive, replace_node) + } } } } diff --git a/src/models/default_configs.rs b/src/models/default_configs.rs index 81ad01435..269136979 100644 --- a/src/models/default_configs.rs +++ b/src/models/default_configs.rs @@ -31,6 +31,8 @@ pub const THRIFT: &str = "thrift"; pub const STRINGS: &str = "strings"; pub const TS_SCHEME: &str = "scm"; // We support scheme files that contain tree-sitter query +pub const REGEX_QUERY_PREFIX: &str = "rgx "; + #[cfg(test)] //FIXME: Remove this hack by not passing PiranhaArguments to SourceCodeUnit pub(crate) const UNUSED_CODE_PATH: &str = "/dev/null"; diff --git a/src/models/matches.rs b/src/models/matches.rs index 202ce8ac5..193169dda 100644 --- a/src/models/matches.rs +++ b/src/models/matches.rs @@ -55,6 +55,18 @@ pub(crate) struct Match { gen_py_str_methods!(Match); impl Match { + pub(crate) fn from_regex( + mtch: ®ex::Match, matches: HashMap, source_code: &str, + ) -> Self { + Match { + matched_string: mtch.as_str().to_string(), + range: Range::from_regex_match(mtch, source_code), + matches, + associated_comma: None, + associated_comments: Vec::new(), + } + } + pub(crate) fn new( matched_string: String, range: tree_sitter::Range, matches: HashMap, ) -> Self { @@ -231,7 +243,7 @@ impl Match { serde_derive::Serialize, Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, )] #[pyclass] -struct Range { +pub(crate) struct Range { #[pyo3(get)] start_byte: usize, #[pyo3(get)] @@ -260,6 +272,32 @@ impl From for Range { } gen_py_str_methods!(Range); +impl Range { + pub(crate) fn from_regex_match(mtch: ®ex::Match, source_code: &str) -> Self { + Self { + start_byte: mtch.start(), + end_byte: mtch.end(), + start_point: position_for_offset(source_code.as_bytes(), mtch.start()), + end_point: position_for_offset(source_code.as_bytes(), mtch.end()), + } + } +} + +// Finds the position (col and row number) for a given offset. +// Copied from tree-sitter tests [https://github.com/tree-sitter/tree-sitter/blob/d0029a15273e526925a764033e9b7f18f96a7ce5/cli/src/parse.rs#L364] +fn position_for_offset(input: &[u8], offset: usize) -> Point { + let mut result = Point { row: 0, column: 0 }; + for c in &input[0..offset] { + if *c as char == '\n' { + result.row += 1; + result.column = 0; + } else { + result.column += 1; + } + } + result +} + /// A range of positions in a multi-line text document, both in terms of bytes and of /// rows and columns. #[derive( diff --git a/src/models/rule_store.rs b/src/models/rule_store.rs index d91780fc9..d69fd5f2d 100644 --- a/src/models/rule_store.rs +++ b/src/models/rule_store.rs @@ -79,7 +79,10 @@ impl RuleStore { pub(crate) fn query(&mut self, cg_pattern: &CGPattern) -> &CompiledCGPattern { let pattern = cg_pattern.pattern(); if pattern.starts_with("rgx ") { - panic!("Regex not supported.") + return &*self + .rule_query_cache + .entry(pattern) + .or_insert_with(|| CompiledCGPattern::R(cg_pattern.extract_regex().unwrap())); } &*self diff --git a/src/models/unit_tests/rule_graph_validation_test.rs b/src/models/unit_tests/rule_graph_validation_test.rs index f12d9880a..9102ff837 100644 --- a/src/models/unit_tests/rule_graph_validation_test.rs +++ b/src/models/unit_tests/rule_graph_validation_test.rs @@ -118,13 +118,3 @@ fn test_filter_bad_arg_contains_n_sibling() { .sibling_count(2) .build(); } - -#[test] -#[should_panic(expected = "Regex not supported")] -fn test_unsupported_regex() { - RuleGraphBuilder::default() - .rules(vec![ - piranha_rule! {name = "Test rule", query = "rgx (\\w+) (\\w)+"}, - ]) - .build(); -} diff --git a/src/tests/test_piranha_java.rs b/src/tests/test_piranha_java.rs index 6340b8d9d..fe7ffa6d9 100644 --- a/src/tests/test_piranha_java.rs +++ b/src/tests/test_piranha_java.rs @@ -62,10 +62,11 @@ create_rewrite_tests! { test_non_seed_user_rule: "non_seed_user_rule", 1, substitutions = substitutions! {"input_type_name" => "ArrayList"}; test_insert_field_and_initializer: "insert_field_and_initializer", 1; test_user_option_delete_if_empty: "user_option_delete_if_empty", 1; - test_user_option_do_not_delete_if_empty : "user_option_do_not_delete_if_empty", 1, delete_file_if_empty =false; + test_user_option_do_not_delete_if_empty : "user_option_do_not_delete_if_empty", 1, delete_file_if_empty = false; test_new_line_character_used_in_string_literal: "new_line_character_used_in_string_literal", 1; test_java_delete_method_invocation_argument: "delete_method_invocation_argument", 1; test_java_delete_method_invocation_argument_no_op: "delete_method_invocation_argument_no_op", 0; + test_regex_based_matcher: "regex_based_matcher", 1, cleanup_comments = true; } create_match_tests! { diff --git a/src/utilities/mod.rs b/src/utilities/mod.rs index 7035b8567..7ca6df029 100644 --- a/src/utilities/mod.rs +++ b/src/utilities/mod.rs @@ -11,6 +11,7 @@ Copyright (c) 2023 Uber Technologies, Inc. limitations under the License. */ +pub(crate) mod regex_utilities; pub(crate) mod tree_sitter_utilities; use std::collections::HashMap; use std::error::Error; diff --git a/src/utilities/regex_utilities.rs b/src/utilities/regex_utilities.rs new file mode 100644 index 000000000..3cf8d54d2 --- /dev/null +++ b/src/utilities/regex_utilities.rs @@ -0,0 +1,73 @@ +/* + Copyright (c) 2023 Uber Technologies, Inc. + +

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file + except in compliance with the License. You may obtain a copy of the License at +

http://www.apache.org/licenses/LICENSE-2.0 + +

Unless required by applicable law or agreed to in writing, software distributed under the + License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + express or implied. See the License for the specific language governing permissions and + limitations under the License. +*/ + +use crate::models::matches::Match; +use itertools::Itertools; +use regex::Regex; +use std::collections::HashMap; +use tree_sitter::Node; + +/// Applies the query upon the given `node`, and gets all the matches +/// # Arguments +/// * `node` - the root node to apply the query upon +/// * `source_code` - the corresponding source code string for the node. +/// * `recursive` - if `true` it matches the query to `self` and `self`'s sub-ASTs, else it matches the `query` only to `self`. +/// * `replace_node` - node to replace +/// +/// # Returns +/// List containing all the matches against `node` +pub(crate) fn get_all_matches_for_regex( + node: &Node, source_code: String, regex: &Regex, recursive: bool, replace_node: Option, +) -> Vec { + let all_captures = regex.captures_iter(&source_code).collect_vec(); + let names = regex.capture_names().collect_vec(); + let mut all_matches = vec![]; + for captures in all_captures { + // Check if the range of the self (node), and the range of outermost node captured by the query are equal. + let range_matches_node = node.start_byte() == captures.get(0).unwrap().start() + && node.end_byte() == captures.get(0).unwrap().end(); + let range_matches_inside_node = node.start_byte() <= captures.get(0).unwrap().start() + && node.end_byte() >= captures.get(0).unwrap().end(); + if (recursive && range_matches_inside_node) || range_matches_node { + let replace_node_match = if let Some(ref rn) = replace_node { + captures + .name(rn) + .unwrap_or_else(|| panic!("The tag {rn} provided in the replace node is not present")) + } else { + captures.get(0).unwrap() + }; + let matches = extract_captures(&captures, &names); + all_matches.push(Match::from_regex( + &replace_node_match, + matches, + &source_code, + )); + } + } + all_matches +} + +// Creates a hashmap from the capture group(name) to the corresponding code snippet. +fn extract_captures( + captures: ®ex::Captures<'_>, names: &Vec>, +) -> HashMap { + names + .iter() + .flatten() + .flat_map(|x| { + captures + .name(x) + .map(|v| (x.to_string(), v.as_str().to_string())) + }) + .collect() +} diff --git a/test-resources/java/regex_based_matcher/configurations/edges.toml b/test-resources/java/regex_based_matcher/configurations/edges.toml new file mode 100644 index 000000000..f6ba0a5c6 --- /dev/null +++ b/test-resources/java/regex_based_matcher/configurations/edges.toml @@ -0,0 +1,31 @@ +# Copyright (c) 2023 Uber Technologies, Inc. +# +#

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file +# except in compliance with the License. You may obtain a copy of the License at +#

http://www.apache.org/licenses/LICENSE-2.0 +# +#

Unless required by applicable law or agreed to in writing, software distributed under the +# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing permissions and +# limitations under the License. + +[[edges]] +scope = "File" +from = "update_import" +to = ["update_list_int"] + +[[edges]] +scope = "Method" +from = "update_list_int" +to = ["update_add"] + + +[[edges]] +scope = "File" +from = "delete_import_our_map_of" +to = ["change_type_from_OurHashMap", "add_import_For_hashmap"] + +[[edges]] +scope = "Method" +from = "change_type_from_OurHashMap" +to = ["update_push"] diff --git a/test-resources/java/regex_based_matcher/configurations/rules.toml b/test-resources/java/regex_based_matcher/configurations/rules.toml new file mode 100644 index 000000000..322a4a968 --- /dev/null +++ b/test-resources/java/regex_based_matcher/configurations/rules.toml @@ -0,0 +1,105 @@ +# Copyright (c) 2023 Uber Technologies, Inc. +# +#

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file +# except in compliance with the License. You may obtain a copy of the License at +#

http://www.apache.org/licenses/LICENSE-2.0 +# +#

Unless required by applicable law or agreed to in writing, software distributed under the +# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing permissions and +# limitations under the License. + +# Replace foo().bar().baz() with `true` inside methods not nnotated as @DoNotCleanup +[[rules]] +name = "replace_call" +query = """rgx (?Pfoo\\(\\)\\.bar\\(\\)\\.baz\\(\\))""" +replace_node = "n1" +replace = "true" +groups = ["replace_expression_with_boolean_literal"] +[[rules.filters]] +enclosing_node = """(method_declaration) @md""" +not_contains = ["""rgx @DoNotCleanup"""] + +# Before: +# abc().def().ghi() +# abc().fed().ghi() +[[rules]] +name = "replace_call_def_fed" +query = """rgx (?Pabc\\(\\)\\.(?Pdef)\\(\\)\\.ghi\\(\\))""" +replace_node = "m_def" +replace = "fed" + + +# The below three rules do a dummy type migration from OurListOfInteger to List + +# Updates the import statement from `out.list.OurListOfInteger` to `java.util.List` +[[rules]] +name = "update_import" +query = """rgx (?Pour\\.list\\.OurListOfInteger)""" +replace_node = "n" +replace = "java.util.List" + +# Updates the type of local variables from `OurListOfInteger` to `List` +[[rules]] +name = "update_list_int" +query = """rgx (?P(?POurListOf(?PInteger))\\s*(?P\\w+)\\s*=.*;)""" +replace_node = "type" +replace = "List<@param>" +is_seed_rule = false +[[rules.filter]] +enclosing_node = "(method_declaration) @cmd" + +# Updates the relevant callsite from `addToOurList` to `add` +[[rules]] +name = "update_add" +query = """rgx (?P@name\\.(?PaddToOurList)\\(\\w+\\))""" +replace_node = "m_name" +replace = "add" +holes = ["name"] +is_seed_rule = false + + +# The below three rules do a dummy type migration OurMapOf{T1}{T2} to HashMap. For example, from OurMapOfStringInteger to HashMap. +# This is to exercise non-constant regex matches for replace_node. + +# Deletes the import statement `our.map.OurMapOf...` +[[rules]] +name = "delete_import_our_map_of" +query = """rgx (?Pimport our\\.map\\.OurMapOf\\w+;)""" +replace_node = "n" +replace = "" + +# Adds Import to java.util.Hashmap if absent +[[rules]] +name = "add_import_For_hashmap" +query = """(package_declaration) @pkg""" +replace_node = "pkg" +replace = "@pkg\nimport java.util.HashMap;" +is_seed_rule = false +[[rules.filters]] +enclosing_node = "(program) @prgrm" +not_contains = [ + "((import_declaration) @im (#match? @im \"java.util.HashMap\"))", +] + +# Before: +# OurMapOfStringInteger +# After: +# HashMap +[[rules]] +name = "change_type_from_OurHashMap" +query = """rgx (?P(?P\\bOurMapOf(?P[A-Z]\\w+)(?P[A-Z]\\w+))\\s*(?P\\w+)\\s*=.*;)""" +replace_node = "type" +replace = "HashMap<@param1, @param2>" +is_seed_rule = false +[[rules.filter]] +enclosing_node = "(method_declaration) @cmd" + +# Updates the relevant callsite from `pushIntoOurMap` to `push` +[[rules]] +name = "update_push" +query = """rgx (?P@name\\.(?PpushIntoOurMap)\\(.*\\))""" +replace_node = "m_name" +replace = "push" +holes = ["name"] +is_seed_rule = false diff --git a/test-resources/java/regex_based_matcher/expected/Sample.java b/test-resources/java/regex_based_matcher/expected/Sample.java new file mode 100644 index 000000000..41282de1f --- /dev/null +++ b/test-resources/java/regex_based_matcher/expected/Sample.java @@ -0,0 +1,55 @@ +package com.uber.piranha; + +import java.util.HashMap; +import java.util.List; +import not.our.map.NotOurMapOfDoubleInteger; + +class A { + + void foobar() { + System.out.println("Hello World!"); + System.out.println(true); + } + + @DoNotCleanup + void barfn() { + boolean b = foo().bar().baz(); + System.out.println(b); + } + + void foofn() { + int total = abc().fed().ghi(); + } + + void someTypeChange() { + // Will get updated + List a = getList(); + Integer item = getItem(); + a.add(item); + + // Will not get updated + List b = getListStr(); + String item = getItemStr(); + b.add(item); + } + + void someOtherTypeChange() { + // Will get updated + HashMap siMap = getMapSI(); + String sKey = getStrKey(); + Integer iItem = getIItem(); + siMap.push(sKey, iItem); + + // Will get updated + HashMap lfMap = getMapLF(); + Long lKey = getLongKey(); + Float fItem = getFItem(); + lfMap.push(lKey, fItem); + + // Will not get updated + NotOurMapOfDoubleInteger dlMap = getMapDL(); + Double dKey = getDoubleKey(); + dlMap.pushIntoOurMap(dKey, iItem); + } + +} diff --git a/test-resources/java/regex_based_matcher/input/Sample.java b/test-resources/java/regex_based_matcher/input/Sample.java new file mode 100644 index 000000000..97640ebfa --- /dev/null +++ b/test-resources/java/regex_based_matcher/input/Sample.java @@ -0,0 +1,60 @@ +package com.uber.piranha; + +import our.list.OurListOfInteger; +import our.map.OurMapOfStringInteger; +import our.map.OurMapOfLongFloat; +import not.our.map.NotOurMapOfDoubleInteger; + +class A { + + void foobar() { + // Will be removed + boolean b = foo().bar().baz(); + if (b) { + System.out.println("Hello World!"); + } + System.out.println(b); + } + + @DoNotCleanup + void barfn() { + boolean b = foo().bar().baz(); + System.out.println(b); + } + + void foofn() { + int total = abc().def().ghi(); + } + + void someTypeChange() { + // Will get updated + OurListOfInteger a = getList(); + Integer item = getItem(); + a.addToOurList(item); + + // Will not get updated + List b = getListStr(); + String item = getItemStr(); + b.add(item); + } + + void someOtherTypeChange() { + // Will get updated + OurMapOfStringInteger siMap = getMapSI(); + String sKey = getStrKey(); + Integer iItem = getIItem(); + siMap.pushIntoOurMap(sKey, iItem); + + // Will get updated + OurMapOfLongFloat lfMap = getMapLF(); + Long lKey = getLongKey(); + Float fItem = getFItem(); + lfMap.pushIntoOurMap(lKey, fItem); + + // Will not get updated + NotOurMapOfDoubleInteger dlMap = getMapDL(); + Double dKey = getDoubleKey(); + dlMap.pushIntoOurMap(dKey, iItem); + } + +}