Skip to content

Commit

Permalink
[pylint] - use sets when possible for PLR1714 autofix (`repeated-…
Browse files Browse the repository at this point in the history
…equality-comparison`) (#14372)
  • Loading branch information
diceroll123 authored Nov 18, 2024
1 parent 38a385f commit 5776535
Show file tree
Hide file tree
Showing 5 changed files with 615 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,11 @@
foo == "a" or ("c" != bar and "d" != bar) or foo == "b" # Multiple targets

foo == "a" and "c" != bar or foo == "b" and "d" != bar # Multiple targets

foo == 1 or foo == True # Different types, same hashed value

foo == 1 or foo == 1.0 # Different types, same hashed value

foo == False or foo == 0 # Different types, same hashed value

foo == 0.0 or foo == 0j # Different types, same hashed value
25 changes: 23 additions & 2 deletions crates/ruff_linter/src/rules/pylint/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ mod tests {
use crate::registry::Rule;
use crate::rules::pylint;

use crate::assert_messages;
use crate::settings::types::PythonVersion;
use crate::settings::types::{PreviewMode, PythonVersion};
use crate::settings::LinterSettings;
use crate::test::test_path;
use crate::{assert_messages, settings};

#[test_case(Rule::SingledispatchMethod, Path::new("singledispatch_method.py"))]
#[test_case(
Expand Down Expand Up @@ -405,4 +405,25 @@ mod tests {
assert_messages!(diagnostics);
Ok(())
}

#[test_case(
Rule::RepeatedEqualityComparison,
Path::new("repeated_equality_comparison.py")
)]
fn preview_rules(rule_code: Rule, path: &Path) -> Result<()> {
let snapshot = format!(
"preview__{}_{}",
rule_code.noqa_code(),
path.to_string_lossy()
);
let diagnostics = test_path(
Path::new("pylint").join(path).as_path(),
&settings::LinterSettings {
preview: PreviewMode::Enabled,
..settings::LinterSettings::for_rule(rule_code)
},
)?;
assert_messages!(snapshot, diagnostics);
Ok(())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ use crate::Locator;
/// If the items are hashable, use a `set` for efficiency; otherwise, use a
/// `tuple`.
///
/// In [preview], this rule will try to determine if the values are hashable
/// and the fix will use a `set` if they are. If unable to determine, the fix
/// will use a `tuple` and continue to suggest the use of a `set`.
///
/// ## Example
/// ```python
/// foo == "bar" or foo == "baz" or foo == "qux"
Expand All @@ -42,21 +46,29 @@ use crate::Locator;
/// - [Python documentation: Comparisons](https://docs.python.org/3/reference/expressions.html#comparisons)
/// - [Python documentation: Membership test operations](https://docs.python.org/3/reference/expressions.html#membership-test-operations)
/// - [Python documentation: `set`](https://docs.python.org/3/library/stdtypes.html#set)
///
/// [preview]: https://docs.astral.sh/ruff/preview/
#[violation]
pub struct RepeatedEqualityComparison {
expression: SourceCodeSnippet,
all_hashable: bool,
}

impl AlwaysFixableViolation for RepeatedEqualityComparison {
#[derive_message_formats]
fn message(&self) -> String {
if let Some(expression) = self.expression.full_display() {
format!(
"Consider merging multiple comparisons: `{expression}`. Use a `set` if the elements are hashable."
)
} else {
"Consider merging multiple comparisons. Use a `set` if the elements are hashable."
.to_string()
match (self.expression.full_display(), self.all_hashable) {
(Some(expression), false) => {
format!("Consider merging multiple comparisons: `{expression}`. Use a `set` if the elements are hashable.")
}
(Some(expression), true) => {
format!("Consider merging multiple comparisons: `{expression}`.")
}
(None, false) => {
"Consider merging multiple comparisons. Use a `set` if the elements are hashable."
.to_string()
}
(None, true) => "Consider merging multiple comparisons.".to_string(),
}
}

Expand Down Expand Up @@ -121,14 +133,23 @@ pub(crate) fn repeated_equality_comparison(checker: &mut Checker, bool_op: &ast:
continue;
}

// if we can determine that all the values are hashable, we can use a set
// TODO: improve with type inference
let all_hashable = checker.settings.preview.is_enabled()
&& comparators
.iter()
.all(|comparator| comparator.is_literal_expr());

let mut diagnostic = Diagnostic::new(
RepeatedEqualityComparison {
expression: SourceCodeSnippet::new(merged_membership_test(
expr,
bool_op.op,
&comparators,
checker.locator(),
all_hashable,
)),
all_hashable,
},
bool_op.range(),
);
Expand All @@ -140,6 +161,20 @@ pub(crate) fn repeated_equality_comparison(checker: &mut Checker, bool_op: &ast:
let before = bool_op.values.iter().take(*first).cloned();
let after = bool_op.values.iter().skip(last + 1).cloned();

let comparator = if all_hashable {
Expr::Set(ast::ExprSet {
elts: comparators.iter().copied().cloned().collect(),
range: TextRange::default(),
})
} else {
Expr::Tuple(ast::ExprTuple {
elts: comparators.iter().copied().cloned().collect(),
range: TextRange::default(),
ctx: ExprContext::Load,
parenthesized: true,
})
};

diagnostic.set_fix(Fix::unsafe_edit(Edit::range_replacement(
checker.generator().expr(&Expr::BoolOp(ast::ExprBoolOp {
op: bool_op.op,
Expand All @@ -150,12 +185,7 @@ pub(crate) fn repeated_equality_comparison(checker: &mut Checker, bool_op: &ast:
BoolOp::Or => Box::from([CmpOp::In]),
BoolOp::And => Box::from([CmpOp::NotIn]),
},
comparators: Box::from([Expr::Tuple(ast::ExprTuple {
elts: comparators.iter().copied().cloned().collect(),
range: TextRange::default(),
ctx: ExprContext::Load,
parenthesized: true,
})]),
comparators: Box::from([comparator]),
range: bool_op.range(),
})))
.chain(after)
Expand Down Expand Up @@ -231,11 +261,13 @@ fn to_allowed_value<'a>(
}

/// Generate a string like `obj in (a, b, c)` or `obj not in (a, b, c)`.
/// If `all_hashable` is `true`, the string will use a set instead of a tuple.
fn merged_membership_test(
left: &Expr,
op: BoolOp,
comparators: &[&Expr],
locator: &Locator,
all_hashable: bool,
) -> String {
let op = match op {
BoolOp::Or => "in",
Expand All @@ -246,5 +278,10 @@ fn merged_membership_test(
.iter()
.map(|comparator| locator.slice(comparator))
.join(", ");

if all_hashable {
return format!("{left} {op} {{{members}}}",);
}

format!("{left} {op} ({members})",)
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
---
source: crates/ruff_linter/src/rules/pylint/mod.rs
snapshot_kind: text
---
repeated_equality_comparison.py:2:1: PLR1714 [*] Consider merging multiple comparisons: `foo in ("a", "b")`. Use a `set` if the elements are hashable.
|
Expand Down Expand Up @@ -375,3 +374,82 @@ repeated_equality_comparison.py:65:16: PLR1714 [*] Consider merging multiple com
65 |+foo == "a" or (bar not in ("c", "d")) or foo == "b" # Multiple targets
66 66 |
67 67 | foo == "a" and "c" != bar or foo == "b" and "d" != bar # Multiple targets
68 68 |

repeated_equality_comparison.py:69:1: PLR1714 [*] Consider merging multiple comparisons: `foo in (1, True)`. Use a `set` if the elements are hashable.
|
67 | foo == "a" and "c" != bar or foo == "b" and "d" != bar # Multiple targets
68 |
69 | foo == 1 or foo == True # Different types, same hashed value
| ^^^^^^^^^^^^^^^^^^^^^^^ PLR1714
70 |
71 | foo == 1 or foo == 1.0 # Different types, same hashed value
|
= help: Merge multiple comparisons

Unsafe fix
66 66 |
67 67 | foo == "a" and "c" != bar or foo == "b" and "d" != bar # Multiple targets
68 68 |
69 |-foo == 1 or foo == True # Different types, same hashed value
69 |+foo in (1, True) # Different types, same hashed value
70 70 |
71 71 | foo == 1 or foo == 1.0 # Different types, same hashed value
72 72 |

repeated_equality_comparison.py:71:1: PLR1714 [*] Consider merging multiple comparisons: `foo in (1, 1.0)`. Use a `set` if the elements are hashable.
|
69 | foo == 1 or foo == True # Different types, same hashed value
70 |
71 | foo == 1 or foo == 1.0 # Different types, same hashed value
| ^^^^^^^^^^^^^^^^^^^^^^ PLR1714
72 |
73 | foo == False or foo == 0 # Different types, same hashed value
|
= help: Merge multiple comparisons

Unsafe fix
68 68 |
69 69 | foo == 1 or foo == True # Different types, same hashed value
70 70 |
71 |-foo == 1 or foo == 1.0 # Different types, same hashed value
71 |+foo in (1, 1.0) # Different types, same hashed value
72 72 |
73 73 | foo == False or foo == 0 # Different types, same hashed value
74 74 |

repeated_equality_comparison.py:73:1: PLR1714 [*] Consider merging multiple comparisons: `foo in (False, 0)`. Use a `set` if the elements are hashable.
|
71 | foo == 1 or foo == 1.0 # Different types, same hashed value
72 |
73 | foo == False or foo == 0 # Different types, same hashed value
| ^^^^^^^^^^^^^^^^^^^^^^^^ PLR1714
74 |
75 | foo == 0.0 or foo == 0j # Different types, same hashed value
|
= help: Merge multiple comparisons

Unsafe fix
70 70 |
71 71 | foo == 1 or foo == 1.0 # Different types, same hashed value
72 72 |
73 |-foo == False or foo == 0 # Different types, same hashed value
73 |+foo in (False, 0) # Different types, same hashed value
74 74 |
75 75 | foo == 0.0 or foo == 0j # Different types, same hashed value

repeated_equality_comparison.py:75:1: PLR1714 [*] Consider merging multiple comparisons: `foo in (0.0, 0j)`. Use a `set` if the elements are hashable.
|
73 | foo == False or foo == 0 # Different types, same hashed value
74 |
75 | foo == 0.0 or foo == 0j # Different types, same hashed value
| ^^^^^^^^^^^^^^^^^^^^^^^ PLR1714
|
= help: Merge multiple comparisons

Unsafe fix
72 72 |
73 73 | foo == False or foo == 0 # Different types, same hashed value
74 74 |
75 |-foo == 0.0 or foo == 0j # Different types, same hashed value
75 |+foo in (0.0, 0j) # Different types, same hashed value
Loading

0 comments on commit 5776535

Please sign in to comment.