diff --git a/Cargo.lock b/Cargo.lock index e5e4c63..214ce2d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -69,7 +69,6 @@ name = "common" version = "1.0.0" dependencies = [ "indoc", - "lexical-sort", "pep440_rs", "pep508_rs", "rstest", @@ -528,6 +527,7 @@ version = "2.4.3" dependencies = [ "common", "indoc", + "lexical-sort", "pyo3", "regex", "rstest", diff --git a/common/Cargo.toml b/common/Cargo.toml index 1ac34de..65b3bc9 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -11,7 +11,6 @@ edition = "2021" taplo = { version = "0.13.2" } # formatter pep508_rs = { version = "0.6.1" } pep440_rs = { version = "0.6.5" } # align up with pep508_rs for now https://github.com/konstin/pep508_rs/issues/19 -lexical-sort = { version = "0.3.1" } [dev-dependencies] rstest = { version = "0.23.0" } # parametrized tests diff --git a/common/src/array.rs b/common/src/array.rs index 8b6f3ad..acfa502 100644 --- a/common/src/array.rs +++ b/common/src/array.rs @@ -1,143 +1,152 @@ use std::cell::RefCell; +use std::cmp::Ordering; use std::collections::HashMap; - -use lexical_sort::{natural_lexical_cmp, StringSort}; +use std::hash::Hash; use taplo::syntax::SyntaxKind::{ARRAY, COMMA, NEWLINE, STRING, VALUE, WHITESPACE}; use taplo::syntax::{SyntaxElement, SyntaxKind, SyntaxNode}; use crate::create::{make_comma, make_newline}; use crate::string::{load_text, update_content}; +use crate::util::{find_first, iter}; pub fn transform(node: &SyntaxNode, transform: &F) where F: Fn(&str) -> String, { - for array in node.children_with_tokens() { - if array.kind() == ARRAY { - for array_entry in array.as_node().unwrap().children_with_tokens() { - if array_entry.kind() == VALUE { - update_content(array_entry.as_node().unwrap(), transform); - } - } - } - } + iter(node, [ARRAY, VALUE].as_ref(), &|array_entry| { + update_content(array_entry, transform); + }); } #[allow(clippy::range_plus_one, clippy::too_many_lines)] -pub fn sort(node: &SyntaxNode, transform: F) +pub fn sort(node: &SyntaxNode, to_key: K, 
cmp: &C) where - F: Fn(&str) -> String, + K: Fn(&SyntaxNode) -> Option, + C: Fn(&T, &T) -> Ordering, + T: Clone + Eq + Hash, { - for array in node.children_with_tokens() { - if array.kind() == ARRAY { - let array_node = array.as_node().unwrap(); - let has_trailing_comma = array_node - .children_with_tokens() - .map(|x| x.kind()) - .filter(|x| *x == COMMA || *x == VALUE) - .last() - == Some(COMMA); - let multiline = array_node.children_with_tokens().any(|e| e.kind() == NEWLINE); - let mut value_set = Vec::>::new(); - let entry_set = RefCell::new(Vec::::new()); - let mut key_to_pos = HashMap::::new(); + iter(node, [ARRAY].as_ref(), &|array| { + let has_trailing_comma = array + .children_with_tokens() + .map(|x| x.kind()) + .filter(|x| *x == COMMA || *x == VALUE) + .last() + == Some(COMMA); + let multiline = array.children_with_tokens().any(|e| e.kind() == NEWLINE); - let mut add_to_value_set = |entry: String| { - let mut entry_set_borrow = entry_set.borrow_mut(); - if !entry_set_borrow.is_empty() { - key_to_pos.insert(entry, value_set.len()); - value_set.push(entry_set_borrow.clone()); - entry_set_borrow.clear(); - } - }; - let mut entries = Vec::::new(); - let mut has_value = false; - let mut previous_is_bracket_open = false; - let mut entry_value = String::new(); - let mut count = 0; + let mut entries = Vec::::new(); + let mut order_sets = Vec::>::new(); + let mut key_to_order_set = HashMap::::new(); + let current_set = RefCell::new(Vec::::new()); + let mut current_set_value: Option = None; + let mut previous_is_bracket_open = false; + + let mut add_to_order_sets = |entry: T| { + let mut entry_set_borrow = current_set.borrow_mut(); + if !entry_set_borrow.is_empty() { + key_to_order_set.insert(entry, order_sets.len()); + order_sets.push(entry_set_borrow.clone()); + entry_set_borrow.clear(); + } + }; + + let mut count = 0; - for entry in array_node.children_with_tokens() { - count += 1; - if previous_is_bracket_open { - // make sure ends with trailing comma - if 
entry.kind() == NEWLINE || entry.kind() == WHITESPACE { - continue; + // collect elements to order into to_order_sets, the rest goes into entries + for entry in array.children_with_tokens() { + count += 1; + if previous_is_bracket_open { + // make sure ends with trailing comma + if entry.kind() == NEWLINE || entry.kind() == WHITESPACE { + continue; + } + previous_is_bracket_open = false; + } + match &entry.kind() { + SyntaxKind::BRACKET_START => { + entries.push(entry); + if multiline { + entries.push(make_newline()); } - previous_is_bracket_open = false; + previous_is_bracket_open = true; } - match &entry.kind() { - SyntaxKind::BRACKET_START => { - entries.push(entry); - if multiline { - entries.push(make_newline()); + SyntaxKind::BRACKET_END => { + match current_set_value.take() { + None => { + entries.extend(current_set.borrow_mut().clone()); } - previous_is_bracket_open = true; - } - SyntaxKind::BRACKET_END => { - if has_value { - add_to_value_set(entry_value.clone()); - } else { - entries.extend(entry_set.borrow_mut().clone()); + Some(val) => { + add_to_order_sets(val); } - entries.push(entry); } - VALUE => { - if has_value { + entries.push(entry); + } + VALUE => { + match current_set_value.take() { + None => {} + Some(val) => { if multiline { - entry_set.borrow_mut().push(make_newline()); - } - add_to_value_set(entry_value.clone()); - } - has_value = true; - let value_node = entry.as_node().unwrap(); - let mut found_string = false; - for child in value_node.children_with_tokens() { - let kind = child.kind(); - if kind == STRING { - entry_value = transform(load_text(child.as_token().unwrap().text(), STRING).as_str()); - found_string = true; - break; + current_set.borrow_mut().push(make_newline()); } + add_to_order_sets(val); } - if !found_string { - // abort if not correct types - return; - } - entry_set.borrow_mut().push(entry); - entry_set.borrow_mut().push(make_comma()); } - NEWLINE => { - entry_set.borrow_mut().push(entry); - if has_value { - 
add_to_value_set(entry_value.clone()); - has_value = false; - } - } - COMMA => {} - _ => { - entry_set.borrow_mut().push(entry); + let value_node = entry.as_node().unwrap(); + current_set_value = to_key(value_node); + + current_set.borrow_mut().push(entry); + current_set.borrow_mut().push(make_comma()); + } + NEWLINE => { + current_set.borrow_mut().push(entry); + if current_set_value.is_some() { + add_to_order_sets(current_set_value.unwrap()); + current_set_value = None; } } + COMMA => {} + _ => { + current_set.borrow_mut().push(entry); + } } + } - let mut order: Vec = key_to_pos.clone().into_keys().collect(); - order.string_sort_unstable(natural_lexical_cmp); - let end = entries.split_off(if multiline { 2 } else { 1 }); - for key in order { - entries.extend(value_set[key_to_pos[&key]].clone()); - } - entries.extend(end); - array_node.splice_children(0..count, entries); - if !has_trailing_comma { - if let Some((i, _)) = array_node - .children_with_tokens() - .enumerate() - .filter(|(_, x)| x.kind() == COMMA) - .last() - { - array_node.splice_children(i..i + 1, vec![]); - } + let trailing_content = entries.split_off(if multiline { 2 } else { 1 }); + let mut order: Vec = key_to_order_set.keys().cloned().collect(); + order.sort_by(&cmp); + for key in order { + entries.extend(order_sets[key_to_order_set[&key]].clone()); + } + entries.extend(trailing_content); + array.splice_children(0..count, entries); + + if !has_trailing_comma { + if let Some((i, _)) = array + .children_with_tokens() + .enumerate() + .filter(|(_, x)| x.kind() == COMMA) + .last() + { + array.splice_children(i..i + 1, vec![]); } } - } + }); +} + +#[allow(clippy::range_plus_one, clippy::too_many_lines)] +pub fn sort_strings(node: &SyntaxNode, to_key: K, cmp: &C) +where + K: Fn(String) -> String, + C: Fn(&String, &String) -> Ordering, + T: Clone + Eq + Hash, +{ + sort( + node, + |e| -> Option { + find_first(e, &[STRING], &|s| -> String { + to_key(load_text(s.as_token().unwrap().text(), STRING)) + }) + }, 
+ cmp, + ); } diff --git a/common/src/table.rs b/common/src/table.rs index cb83532..f526e85 100644 --- a/common/src/table.rs +++ b/common/src/table.rs @@ -256,6 +256,22 @@ where } } +pub fn find_key(table: &SyntaxNode, key: &str) -> Option { + let mut current_key = String::new(); + for table_entry in table.children_with_tokens() { + if table_entry.kind() == ENTRY { + for entry in table_entry.as_node().unwrap().children_with_tokens() { + if entry.kind() == KEY { + current_key = entry.as_node().unwrap().text().to_string().trim().to_string(); + } else if entry.kind() == VALUE && current_key == key { + return Some(entry.as_node().unwrap().clone()); + } + } + } + } + None +} + pub fn collapse_sub_tables(tables: &mut Tables, name: &str) { let h2p = tables.header_to_pos.clone(); let sub_name_prefix = format!("{name}."); diff --git a/common/src/tests/array_tests.rs b/common/src/tests/array_tests.rs index 4439e50..cd85683 100644 --- a/common/src/tests/array_tests.rs +++ b/common/src/tests/array_tests.rs @@ -4,7 +4,7 @@ use taplo::formatter::{format_syntax, Options}; use taplo::parser::parse; use taplo::syntax::SyntaxKind::{ENTRY, VALUE}; -use crate::array::{sort, transform}; +use crate::array::{sort_strings, transform}; use crate::pep508::format_requirement; #[rstest] @@ -148,7 +148,9 @@ fn test_order_array(#[case] start: &str, #[case] expected: &str) { if children.kind() == ENTRY { for entry in children.as_node().unwrap().children_with_tokens() { if entry.kind() == VALUE { - sort(entry.as_node().unwrap(), str::to_lowercase); + sort_strings::(entry.as_node().unwrap(), |s| s.to_lowercase(), &|lhs, rhs| { + lhs.cmp(rhs) + }); } } } @@ -172,7 +174,9 @@ fn test_reorder_no_trailing_comma(#[case] start: &str, #[case] expected: &str) { if children.kind() == ENTRY { for entry in children.as_node().unwrap().children_with_tokens() { if entry.kind() == VALUE { - sort(entry.as_node().unwrap(), str::to_lowercase); + sort_strings::(entry.as_node().unwrap(), |s| s.to_lowercase(), &|lhs, 
rhs| { + lhs.cmp(rhs) + }); } } } diff --git a/common/src/util.rs b/common/src/util.rs index 6e66dee..f309335 100644 --- a/common/src/util.rs +++ b/common/src/util.rs @@ -1,6 +1,6 @@ -use taplo::syntax::{SyntaxKind, SyntaxNode}; +use taplo::syntax::{SyntaxElement, SyntaxKind, SyntaxNode}; -pub fn iter<F>(node: &SyntaxNode, paths: &[SyntaxKind], transform: &F) +pub fn iter<F>(node: &SyntaxNode, paths: &[SyntaxKind], handle: &F) where F: Fn(&SyntaxNode), { @@ -8,10 +8,26 @@ where if entry.kind() == paths[0] { let found = entry.as_node().unwrap(); if paths.len() == 1 { - transform(found); + handle(found); } else { - iter(found, &paths[1..], transform); + iter(found, &paths[1..], handle); } } } } + +pub fn find_first<F, T>(node: &SyntaxNode, paths: &[SyntaxKind], extract: &F) -> Option<T> where + F: Fn(SyntaxElement) -> T, +{ + for entry in node.children_with_tokens() { + if entry.kind() == paths[0] { + if paths.len() == 1 { + return Some(extract(entry)); + } else { + if let Some(found) = find_first(entry.as_node().unwrap(), &paths[1..], extract) { return Some(found); } + } + } + } + None +} diff --git a/pyproject-fmt/Cargo.toml b/pyproject-fmt/Cargo.toml index eb171fb..92aab2c 100644 --- a/pyproject-fmt/Cargo.toml +++ b/pyproject-fmt/Cargo.toml @@ -16,6 +16,7 @@ crate-type = ["cdylib"] common = {path = "../common" } regex = { version = "1.11.0" } pyo3 = { version = "0.22.5", features = ["abi3-py38"] } # integration with Python +lexical-sort = { version = "0.3.1" } [features] extension-module = ["pyo3/extension-module"] diff --git a/pyproject-fmt/rust/src/build_system.rs b/pyproject-fmt/rust/src/build_system.rs index 74708ec..89d743b 100644 --- a/pyproject-fmt/rust/src/build_system.rs +++ b/pyproject-fmt/rust/src/build_system.rs @@ -1,6 +1,7 @@ -use common::array::{sort, transform}; +use common::array::{sort_strings, transform}; use common::pep508::{format_requirement, get_canonic_requirement_name}; use common::table::{for_entries, reorder_table_keys, Tables}; +use lexical_sort::{lexical_cmp, natural_lexical_cmp}; pub fn 
fix(tables: &Tables, keep_full_version: bool) { let table_element = tables.get("build-system"); @@ -11,10 +12,14 @@ pub fn fix(tables: &Tables, keep_full_version: bool) { for_entries(table, &mut |key, entry| match key.as_str() { "requires" => { transform(entry, &|s| format_requirement(s, keep_full_version)); - sort(entry, |e| get_canonic_requirement_name(e).to_lowercase()); + sort_strings::( + entry, + |s| get_canonic_requirement_name(s.as_str()).to_lowercase(), + &|lhs, rhs| natural_lexical_cmp(lhs, rhs), + ); } "backend-path" => { - sort(entry, str::to_lowercase); + sort_strings::(entry, |s| s.to_lowercase(), &|lhs, rhs| lexical_cmp(lhs, rhs)); } _ => {} }); diff --git a/pyproject-fmt/rust/src/dependency_groups.rs b/pyproject-fmt/rust/src/dependency_groups.rs index 788ac4a..c757aac 100644 --- a/pyproject-fmt/rust/src/dependency_groups.rs +++ b/pyproject-fmt/rust/src/dependency_groups.rs @@ -1,9 +1,11 @@ use common::array::{sort, transform}; use common::pep508::{format_requirement, get_canonic_requirement_name}; -use common::string::update_content; -use common::table::{collapse_sub_tables, for_entries, reorder_table_keys, Tables}; -use common::taplo::syntax::SyntaxKind::{ARRAY, ENTRY, INLINE_TABLE, VALUE}; +use common::string::{load_text, update_content}; +use common::table::{collapse_sub_tables, find_key, for_entries, reorder_table_keys, Tables}; +use common::taplo::syntax::SyntaxKind::{ARRAY, ENTRY, INLINE_TABLE, STRING, VALUE}; use common::util::iter; +use lexical_sort::natural_lexical_cmp; +use std::cmp::Ordering; pub fn fix(tables: &mut Tables, keep_full_version: bool) { collapse_sub_tables(tables, "dependency-groups"); @@ -16,14 +18,52 @@ pub fn fix(tables: &mut Tables, keep_full_version: bool) { for_entries(table, &mut |_key, entry| { // format dependency specifications transform(entry, &|s| format_requirement(s, keep_full_version)); - // update include-group key to double-quoted string + + // update inline table values to double-quoted string, e.g. 
include-group iter(entry, [ARRAY, VALUE, INLINE_TABLE, ENTRY, VALUE].as_ref(), &|node| { update_content(node, |s| String::from(s)); }); + // sort array elements - sort(entry, |e| { - get_canonic_requirement_name(e).to_lowercase() + " " + &format_requirement(e, keep_full_version) - }); + sort::<(u8, String, String), _, _>( + entry, + |node| { + for child in node.children_with_tokens() { + match child.kind() { + STRING => { + let val = load_text(child.as_token().unwrap().text(), STRING); + let package_name = get_canonic_requirement_name(val.as_str()).to_lowercase(); + return Some((0, package_name, val)); + } + INLINE_TABLE => { + match find_key(child.as_node().unwrap(), "include-group") { + None => {} + Some(n) => { + return Some(( + 1, + load_text(n.first_token().unwrap().text(), STRING), + String::from(""), + )); + } + }; + } + _ => {} + } + } + None + }, + &|lhs, rhs| { + let mut res = lhs.0.cmp(&rhs.0); + if res == Ordering::Equal { + res = natural_lexical_cmp(lhs.1.as_str(), rhs.1.as_str()); + if res == Ordering::Equal { + res = natural_lexical_cmp(lhs.2.as_str(), rhs.2.as_str()); + } + } + res + }, + ); }); + reorder_table_keys(table, &["", "dev", "test", "type", "docs"]); } diff --git a/pyproject-fmt/rust/src/project.rs b/pyproject-fmt/rust/src/project.rs index 0a27aba..3130617 100644 --- a/pyproject-fmt/rust/src/project.rs +++ b/pyproject-fmt/rust/src/project.rs @@ -1,18 +1,18 @@ -use std::cell::RefMut; - +use common::array::{sort, sort_strings, transform}; +use common::create::{make_array, make_array_entry, make_comma, make_entry_of_string, make_newline}; +use common::pep508::{format_requirement, get_canonic_requirement_name}; +use common::string::{load_text, update_content}; +use common::table::{collapse_sub_tables, for_entries, reorder_table_keys, Tables}; use common::taplo::syntax::SyntaxKind::{ ARRAY, BRACKET_END, BRACKET_START, COMMA, ENTRY, IDENT, INLINE_TABLE, KEY, NEWLINE, STRING, VALUE, }; use common::taplo::syntax::{SyntaxElement, SyntaxNode}; use 
common::taplo::util::StrExt; use common::taplo::HashSet; +use lexical_sort::natural_lexical_cmp; use regex::Regex; - -use common::array::{sort, transform}; -use common::create::{make_array, make_array_entry, make_comma, make_entry_of_string, make_newline}; -use common::pep508::{format_requirement, get_canonic_requirement_name}; -use common::string::{load_text, update_content}; -use common::table::{collapse_sub_tables, for_entries, reorder_table_keys, Tables}; +use std::cell::RefMut; +use std::cmp::Ordering; pub fn fix( tables: &mut Tables, @@ -55,17 +55,34 @@ pub fn fix( } "dependencies" | "optional-dependencies" => { transform(entry, &|s| format_requirement(s, keep_full_version)); - sort(entry, |e| { - get_canonic_requirement_name(e).to_lowercase() + " " + &format_requirement(e, keep_full_version) - }); + sort::<(String, String), _, _>( + entry, + |node| { + for child in node.children_with_tokens() { + if let STRING = child.kind() { + let val = load_text(child.as_token().unwrap().text(), STRING); + let package_name = get_canonic_requirement_name(val.as_str()).to_lowercase(); + return Some((package_name, val)); + } + } + None + }, + &|lhs, rhs| { + let mut res = natural_lexical_cmp(lhs.0.as_str(), rhs.0.as_str()); + if res == Ordering::Equal { + res = natural_lexical_cmp(lhs.1.as_str(), rhs.1.as_str()); + } + res + }, + ); } "dynamic" | "keywords" => { transform(entry, &|s| String::from(s)); - sort(entry, str::to_lowercase); + sort_strings::(entry, |s| s.to_lowercase(), &|lhs, rhs| natural_lexical_cmp(lhs, rhs)); } "classifiers" => { transform(entry, &|s| String::from(s)); - sort(entry, str::to_lowercase); + sort_strings::(entry, |s| s.to_lowercase(), &|lhs, rhs| natural_lexical_cmp(lhs, rhs)); } _ => {} }); @@ -73,7 +90,7 @@ pub fn fix( generate_classifiers(table, max_supported_python, min_supported_python); for_entries(table, &mut |key, entry| { if key.as_str() == "classifiers" { - sort(entry, str::to_lowercase); + sort_strings::(entry, |s| s.to_lowercase(), 
&|lhs, rhs| natural_lexical_cmp(lhs, rhs)); } }); reorder_table_keys( diff --git a/pyproject-fmt/rust/src/ruff.rs b/pyproject-fmt/rust/src/ruff.rs index d4da1a5..8f35c8d 100644 --- a/pyproject-fmt/rust/src/ruff.rs +++ b/pyproject-fmt/rust/src/ruff.rs @@ -1,6 +1,7 @@ -use common::array::{sort, transform}; +use common::array::{sort_strings, transform}; use common::string::update_content; use common::table::{collapse_sub_tables, for_entries, reorder_table_keys, Tables}; +use lexical_sort::natural_lexical_cmp; #[allow(clippy::too_many_lines)] pub fn fix(tables: &mut Tables) { @@ -92,7 +93,7 @@ pub fn fix(tables: &mut Tables) { | "lint.pylint.allow-dunder-method-names" | "lint.pylint.allow-magic-value-types" => { transform(entry, &|s| String::from(s)); - sort(entry, str::to_lowercase); + sort_strings::(entry, |s| s.to_lowercase(), &|lhs, rhs| natural_lexical_cmp(lhs, rhs)); } "lint.isort.section-order" => { transform(entry, &|s| String::from(s)); @@ -100,7 +101,7 @@ pub fn fix(tables: &mut Tables) { _ => { if key.starts_with("lint.extend-per-file-ignores.") || key.starts_with("lint.per-file-ignores.") { transform(entry, &|s| String::from(s)); - sort(entry, str::to_lowercase); + sort_strings::(entry, |s| s.to_lowercase(), &|lhs, rhs| natural_lexical_cmp(lhs, rhs)); } } }); diff --git a/pyproject-fmt/rust/src/tests/dependency_groups_tests.rs b/pyproject-fmt/rust/src/tests/dependency_groups_tests.rs index 4280fd7..3bb7c46 100644 --- a/pyproject-fmt/rust/src/tests/dependency_groups_tests.rs +++ b/pyproject-fmt/rust/src/tests/dependency_groups_tests.rs @@ -117,9 +117,9 @@ fn evaluate(start: &str, keep_full_version: bool) -> String { #[case::include_many_groups( indoc ! {r#" [dependency-groups] - test = ["a>1"] - docs=["b==1"] - all=[{include-group='test'}, {include-group='docs'}] + all=['c<1', {include-group='test'}, {include-group='docs'}, 'd>1'] + docs = ['b==1'] + test = ['a>1'] "#}, indoc ! 
{r#" [dependency-groups] @@ -130,8 +130,10 @@ fn evaluate(start: &str, keep_full_version: bool) -> String { "b==1", ] all = [ - { include-group = "test" }, + "c<1", + "d>1", { include-group = "docs" }, + { include-group = "test" }, ] "#}, false,