Skip to content

Commit

Permalink
[red-knot] Simplify some traits in ast_ids.rs (#14379)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood authored Nov 16, 2024
1 parent a6a3d3f commit 81d3c41
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 65 deletions.
56 changes: 20 additions & 36 deletions crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,64 +49,50 @@ fn ast_ids<'db>(db: &'db dyn Db, scope: ScopeId) -> &'db AstIds {
semantic_index(db, scope.file(db)).ast_ids(scope.file_scope_id(db))
}

pub trait HasScopedUseId {
/// The type of the ID uniquely identifying the use.
type Id: Copy;

/// Returns the ID that uniquely identifies the use in `scope`.
fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id;
}

/// Uniquely identifies a use of a name in a [`crate::semantic_index::symbol::FileScopeId`].
#[newtype_index]
pub struct ScopedUseId;

impl HasScopedUseId for ast::ExprName {
type Id = ScopedUseId;
pub trait HasScopedUseId {
/// Returns the ID that uniquely identifies the use in `scope`.
fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedUseId;
}

fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id {
impl HasScopedUseId for ast::ExprName {
fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedUseId {
let expression_ref = ExpressionRef::from(self);
expression_ref.scoped_use_id(db, scope)
}
}

impl HasScopedUseId for ast::ExpressionRef<'_> {
type Id = ScopedUseId;

fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id {
fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedUseId {
let ast_ids = ast_ids(db, scope);
ast_ids.use_id(*self)
}
}

pub trait HasScopedAstId {
/// The type of the ID uniquely identifying the node.
type Id: Copy;
/// Uniquely identifies an [`ast::Expr`] in a [`crate::semantic_index::symbol::FileScopeId`].
#[newtype_index]
pub struct ScopedExpressionId;

pub trait HasScopedExpressionId {
/// Returns the ID that uniquely identifies the node in `scope`.
fn scoped_ast_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id;
fn scoped_expression_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedExpressionId;
}

impl<T: HasScopedAstId> HasScopedAstId for Box<T> {
type Id = <T as HasScopedAstId>::Id;

fn scoped_ast_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id {
self.as_ref().scoped_ast_id(db, scope)
impl<T: HasScopedExpressionId> HasScopedExpressionId for Box<T> {
fn scoped_expression_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedExpressionId {
self.as_ref().scoped_expression_id(db, scope)
}
}

/// Uniquely identifies an [`ast::Expr`] in a [`crate::semantic_index::symbol::FileScopeId`].
#[newtype_index]
pub struct ScopedExpressionId;

macro_rules! impl_has_scoped_expression_id {
($ty: ty) => {
impl HasScopedAstId for $ty {
type Id = ScopedExpressionId;

fn scoped_ast_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id {
impl HasScopedExpressionId for $ty {
fn scoped_expression_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedExpressionId {
let expression_ref = ExpressionRef::from(self);
expression_ref.scoped_ast_id(db, scope)
expression_ref.scoped_expression_id(db, scope)
}
}
};
Expand Down Expand Up @@ -146,10 +132,8 @@ impl_has_scoped_expression_id!(ast::ExprSlice);
impl_has_scoped_expression_id!(ast::ExprIpyEscapeCommand);
impl_has_scoped_expression_id!(ast::Expr);

impl HasScopedAstId for ast::ExpressionRef<'_> {
type Id = ScopedExpressionId;

fn scoped_ast_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id {
impl HasScopedExpressionId for ast::ExpressionRef<'_> {
fn scoped_expression_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedExpressionId {
let ast_ids = ast_ids(db, scope);
ast_ids.expression_id(*self)
}
Expand Down
4 changes: 2 additions & 2 deletions crates/red_knot_python_semantic/src/semantic_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use ruff_source_file::LineIndex;

use crate::module_name::ModuleName;
use crate::module_resolver::{resolve_module, Module};
use crate::semantic_index::ast_ids::HasScopedAstId;
use crate::semantic_index::ast_ids::HasScopedExpressionId;
use crate::semantic_index::semantic_index;
use crate::types::{binding_ty, infer_scope_types, Type};
use crate::Db;
Expand Down Expand Up @@ -54,7 +54,7 @@ impl HasTy for ast::ExpressionRef<'_> {
let file_scope = index.expression_scope_id(*self);
let scope = file_scope.to_scope_id(model.db, model.file);

let expression_id = self.scoped_ast_id(model.db, scope);
let expression_id = self.scoped_expression_id(model.db, scope);
infer_scope_types(model.db, scope).expression_ty(expression_id)
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub(crate) use self::infer::{
};
pub(crate) use self::signatures::Signature;
use crate::module_resolver::file_to_module;
use crate::semantic_index::ast_ids::HasScopedAstId;
use crate::semantic_index::ast_ids::HasScopedExpressionId;
use crate::semantic_index::definition::Definition;
use crate::semantic_index::symbol::{self as symbol, ScopeId, ScopedSymbolId};
use crate::semantic_index::{
Expand Down Expand Up @@ -207,7 +207,7 @@ fn definition_expression_ty<'db>(
let index = semantic_index(db, file);
let file_scope = index.expression_scope_id(expression);
let scope = file_scope.to_scope_id(db, file);
let expr_id = expression.scoped_ast_id(db, scope);
let expr_id = expression.scoped_expression_id(db, scope);
if scope == definition.scope(db) {
// expression is in the definition scope
let inference = infer_definition_types(db, definition);
Expand Down
34 changes: 16 additions & 18 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use salsa::plumbing::AsId;

use crate::module_name::ModuleName;
use crate::module_resolver::{file_to_module, resolve_module};
use crate::semantic_index::ast_ids::{HasScopedAstId, HasScopedUseId, ScopedExpressionId};
use crate::semantic_index::ast_ids::{HasScopedExpressionId, HasScopedUseId, ScopedExpressionId};
use crate::semantic_index::definition::{
AssignmentDefinitionKind, Definition, DefinitionKind, DefinitionNodeKey,
ExceptHandlerDefinitionKind, TargetKind,
Expand Down Expand Up @@ -181,7 +181,7 @@ fn infer_unpack_types<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> UnpackResult
let scope = unpack.scope(db);

let result = infer_expression_types(db, value);
let value_ty = result.expression_ty(value.node_ref(db).scoped_ast_id(db, scope));
let value_ty = result.expression_ty(value.node_ref(db).scoped_expression_id(db, scope));

let mut unpacker = Unpacker::new(db, file);
unpacker.unpack(unpack.target(db), value_ty, scope);
Expand Down Expand Up @@ -409,7 +409,7 @@ impl<'db> TypeInferenceBuilder<'db> {
#[track_caller]
fn expression_ty(&self, expr: &ast::Expr) -> Type<'db> {
self.types
.expression_ty(expr.scoped_ast_id(self.db, self.scope()))
.expression_ty(expr.scoped_expression_id(self.db, self.scope()))
}

/// Infers types in the given [`InferenceRegion`].
Expand Down Expand Up @@ -1215,9 +1215,10 @@ impl<'db> TypeInferenceBuilder<'db> {
is_async,
);

self.types
.expressions
.insert(target.scoped_ast_id(self.db, self.scope()), target_ty);
self.types.expressions.insert(
target.scoped_expression_id(self.db, self.scope()),
target_ty,
);
self.add_binding(target.into(), definition, target_ty);
}

Expand Down Expand Up @@ -1607,7 +1608,7 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_standalone_expression(value);

let value_ty = self.expression_ty(value);
let name_ast_id = name.scoped_ast_id(self.db, self.scope());
let name_ast_id = name.scoped_expression_id(self.db, self.scope());

let target_ty = match assignment.target() {
TargetKind::Sequence(unpack) => {
Expand Down Expand Up @@ -2211,18 +2212,14 @@ impl<'db> TypeInferenceBuilder<'db> {
ty
}

fn store_expression_type(
&mut self,
expression: &impl HasScopedAstId<Id = ScopedExpressionId>,
ty: Type<'db>,
) {
fn store_expression_type(&mut self, expression: &impl HasScopedExpressionId, ty: Type<'db>) {
if self.deferred_state.in_string_annotation() {
// Avoid storing the type of expressions that are part of a string annotation because
// the expression ids don't exists in the semantic index. Instead, we'll store the type
// on the string expression itself that represents the annotation.
return;
}
let expr_id = expression.scoped_ast_id(self.db, self.scope());
let expr_id = expression.scoped_expression_id(self.db, self.scope());
let previous = self.types.expressions.insert(expr_id, ty);
assert_eq!(previous, None);
}
Expand Down Expand Up @@ -2541,10 +2538,10 @@ impl<'db> TypeInferenceBuilder<'db> {
.parent_scope_id(self.scope().file_scope_id(self.db))
.expect("A comprehension should never be the top-level scope")
.to_scope_id(self.db, self.file);
result.expression_ty(iterable.scoped_ast_id(self.db, lookup_scope))
result.expression_ty(iterable.scoped_expression_id(self.db, lookup_scope))
} else {
self.extend(result);
result.expression_ty(iterable.scoped_ast_id(self.db, self.scope()))
result.expression_ty(iterable.scoped_expression_id(self.db, self.scope()))
};

let target_ty = if is_async {
Expand All @@ -2556,9 +2553,10 @@ impl<'db> TypeInferenceBuilder<'db> {
.unwrap_with_diagnostic(iterable.into(), &mut self.diagnostics)
};

self.types
.expressions
.insert(target.scoped_ast_id(self.db, self.scope()), target_ty);
self.types.expressions.insert(
target.scoped_expression_id(self.db, self.scope()),
target_ty,
);
self.add_binding(target.into(), definition, target_ty);
}

Expand Down
10 changes: 5 additions & 5 deletions crates/red_knot_python_semantic/src/types/narrow.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::semantic_index::ast_ids::HasScopedAstId;
use crate::semantic_index::ast_ids::HasScopedExpressionId;
use crate::semantic_index::constraint::{Constraint, ConstraintNode, PatternConstraint};
use crate::semantic_index::definition::Definition;
use crate::semantic_index::expression::Expression;
Expand Down Expand Up @@ -291,7 +291,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
{
// SAFETY: we should always have a symbol for every Name node.
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
let rhs_ty = inference.expression_ty(right.scoped_ast_id(self.db, scope));
let rhs_ty = inference.expression_ty(right.scoped_expression_id(self.db, scope));

match if is_positive { *op } else { op.negate() } {
ast::CmpOp::IsNot => {
Expand Down Expand Up @@ -336,7 +336,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
// TODO: add support for PEP 604 union types on the right hand side of `isinstance`
// and `issubclass`, for example `isinstance(x, str | (int | float))`.
match inference
.expression_ty(expr_call.func.scoped_ast_id(self.db, scope))
.expression_ty(expr_call.func.scoped_expression_id(self.db, scope))
.into_function_literal()
.and_then(|f| f.known(self.db))
.and_then(KnownFunction::constraint_function)
Expand All @@ -348,7 +348,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
let symbol = self.symbols().symbol_id_by_name(id).unwrap();

let class_info_ty =
inference.expression_ty(class_info.scoped_ast_id(self.db, scope));
inference.expression_ty(class_info.scoped_expression_id(self.db, scope));

let to_constraint = match function {
KnownConstraintFunction::IsInstance => {
Expand Down Expand Up @@ -414,7 +414,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
// filter our arms with statically known truthiness
.filter(|expr| {
inference
.expression_ty(expr.scoped_ast_id(self.db, scope))
.expression_ty(expr.scoped_expression_id(self.db, scope))
.bool(self.db)
!= match expr_bool_op.op {
BoolOp::And => Truthiness::AlwaysTrue,
Expand Down
4 changes: 2 additions & 2 deletions crates/red_knot_python_semantic/src/types/unpacker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use ruff_db::files::File;
use ruff_python_ast::{self as ast, AnyNodeRef};
use rustc_hash::FxHashMap;

use crate::semantic_index::ast_ids::{HasScopedAstId, ScopedExpressionId};
use crate::semantic_index::ast_ids::{HasScopedExpressionId, ScopedExpressionId};
use crate::semantic_index::symbol::ScopeId;
use crate::types::{Type, TypeCheckDiagnostics, TypeCheckDiagnosticsBuilder};
use crate::Db;
Expand All @@ -29,7 +29,7 @@ impl<'db> Unpacker<'db> {
match target {
ast::Expr::Name(target_name) => {
self.targets
.insert(target_name.scoped_ast_id(self.db, scope), value_ty);
.insert(target_name.scoped_expression_id(self.db, scope), value_ty);
}
ast::Expr::Starred(ast::ExprStarred { value, .. }) => {
self.unpack(value, value_ty, scope);
Expand Down

0 comments on commit 81d3c41

Please sign in to comment.