Skip to content

Commit

Permalink
[pylint] - use sets when possible for PLR1714 autofix
Browse files Browse the repository at this point in the history
  • Loading branch information
diceroll123 committed Nov 15, 2024
1 parent c847cad commit d22ffd5
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,11 @@
foo == "a" or ("c" != bar and "d" != bar) or foo == "b" # Multiple targets

foo == "a" and "c" != bar or foo == "b" and "d" != bar # Multiple targets

foo == 1 or foo == True # Different types, same hashed value, Tuple expected

foo == 1 or foo == 1.0 # Different types, same hashed value, Tuple expected

foo == False or foo == 0 # Different types, same hashed value, Tuple expected

foo == 0.0 or foo == 0j # Different types, same hashed value, Tuple expected
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use itertools::Itertools;
use rustc_hash::{FxBuildHasher, FxHashMap};
use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet};

use ast::ExprContext;
use ruff_diagnostics::{AlwaysFixableViolation, Diagnostic, Edit, Fix};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::comparable::ComparableExpr;
use ruff_python_ast::comparable::{ComparableExpr, HashableExpr};
use ruff_python_ast::helpers::{any_over_expr, contains_effect};
use ruff_python_ast::{self as ast, BoolOp, CmpOp, Expr};
use ruff_python_semantic::SemanticModel;
Expand Down Expand Up @@ -51,12 +51,9 @@ impl AlwaysFixableViolation for RepeatedEqualityComparison {
#[derive_message_formats]
fn message(&self) -> String {
if let Some(expression) = self.expression.full_display() {
format!(
"Consider merging multiple comparisons: `{expression}`. Use a `set` if the elements are hashable."
)
format!("Consider merging multiple comparisons: `{expression}`.")
} else {
"Consider merging multiple comparisons. Use a `set` if the elements are hashable."
.to_string()
"Consider merging multiple comparisons.".to_string()
}
}

Expand Down Expand Up @@ -121,13 +118,20 @@ pub(crate) fn repeated_equality_comparison(checker: &mut Checker, bool_op: &ast:
continue;
}

let mut seen_values =
FxHashSet::with_capacity_and_hasher(comparators.len(), FxBuildHasher);
let use_set = comparators
.iter()
.all(|comparator| seen_values.insert(HashableExpr::from(*comparator)));

let mut diagnostic = Diagnostic::new(
RepeatedEqualityComparison {
expression: SourceCodeSnippet::new(merged_membership_test(
expr,
bool_op.op,
&comparators,
checker.locator(),
use_set,
)),
},
bool_op.range(),
Expand All @@ -140,6 +144,20 @@ pub(crate) fn repeated_equality_comparison(checker: &mut Checker, bool_op: &ast:
let before = bool_op.values.iter().take(*first).cloned();
let after = bool_op.values.iter().skip(last + 1).cloned();

let comparator = if use_set {
Expr::Set(ast::ExprSet {
elts: comparators.iter().copied().cloned().collect(),
range: TextRange::default(),
})
} else {
Expr::Tuple(ast::ExprTuple {
elts: comparators.iter().copied().cloned().collect(),
range: TextRange::default(),
ctx: ExprContext::Load,
parenthesized: true,
})
};

diagnostic.set_fix(Fix::unsafe_edit(Edit::range_replacement(
checker.generator().expr(&Expr::BoolOp(ast::ExprBoolOp {
op: bool_op.op,
Expand All @@ -150,12 +168,7 @@ pub(crate) fn repeated_equality_comparison(checker: &mut Checker, bool_op: &ast:
BoolOp::Or => Box::from([CmpOp::In]),
BoolOp::And => Box::from([CmpOp::NotIn]),
},
comparators: Box::from([Expr::Tuple(ast::ExprTuple {
elts: comparators.iter().copied().cloned().collect(),
range: TextRange::default(),
ctx: ExprContext::Load,
parenthesized: true,
})]),
comparators: Box::from([comparator]),
range: bool_op.range(),
})))
.chain(after)
Expand Down Expand Up @@ -231,11 +244,13 @@ fn to_allowed_value<'a>(
}

/// Generate a string like `obj in (a, b, c)` or `obj not in (a, b, c)`.
/// If `use_set` is `true`, the string will use a set instead of a tuple.
fn merged_membership_test(
left: &Expr,
op: BoolOp,
comparators: &[&Expr],
locator: &Locator,
use_set: bool,
) -> String {
let op = match op {
BoolOp::Or => "in",
Expand All @@ -246,5 +261,10 @@ fn merged_membership_test(
.iter()
.map(|comparator| locator.slice(comparator))
.join(", ");

if use_set {
return format!("{left} {op} {{{members}}}",);
}

format!("{left} {op} ({members})",)
}
Loading

0 comments on commit d22ffd5

Please sign in to comment.