Skip to content

Commit

Permalink
add default scopes to through unions
Browse files Browse the repository at this point in the history
  • Loading branch information
ezekg committed Mar 27, 2024
1 parent d68fb2b commit af3e1a2
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 12 deletions.
42 changes: 30 additions & 12 deletions lib/union_of.rb
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,11 @@ def next_chain_scope(scope, reflection, next_reflection)
table = klass.arel_table
foreign_klass = next_reflection.klass
foreign_table = foreign_klass.arel_table
constraints = []

# This holds our union's constraints. For example, if we're unioning across 3
# tables, then this will hold constraints for all 3 of those tables, so that
# the join on our target table mirrors the union of all 3 associations.
foreign_constraints = []

reflection.union_sources.each do |union_source|
union_reflection = foreign_klass.reflect_on_association(union_source)
Expand All @@ -231,21 +235,33 @@ def next_chain_scope(scope, reflection, next_reflection)
through_reflection.name,
)

constraints << foreign_table[through_reflection.join_foreign_key].eq(through_table[through_reflection.join_primary_key])
foreign_constraints << foreign_table[through_reflection.join_foreign_key].eq(through_table[through_reflection.join_primary_key])
else
constraints << foreign_table[union_reflection.join_foreign_key].eq(table[union_reflection.join_primary_key])
foreign_constraints << foreign_table[union_reflection.join_foreign_key].eq(table[union_reflection.join_primary_key])
end
end

# Flatten union constraints and add any default constraints
foreign_constraint = unless (where_clause = foreign_klass.default_scoped.where_clause).empty?
where_clause.ast.and(foreign_constraints.reduce(&:or))
else
foreign_constraints.reduce(&:or)
end

scope.joins!(
Arel::Nodes::InnerJoin.new(
foreign_table,
Arel::Nodes::On.new(
constraints.reduce(&:or),
foreign_constraint,
),
),
)

# FIXME(ezekg) Why is this needed? Should be handled automatically...
scope.merge!(
scope.default_scoped,
)

scope
end

Expand Down Expand Up @@ -278,7 +294,7 @@ def join_scope(table, foreign_table, foreign_klass, alias_tracker = nil)
# This holds our union's constraints. For example, if we're unioning across 3
# tables, then this will hold constraints for all 3 of those tables, so that
# the join on our target table mirrors the union of all 3 associations.
constraints = []
foreign_constraints = []

union_sources.each do |union_source|
union_reflection = foreign_klass.reflect_on_association(union_source)
Expand All @@ -297,29 +313,31 @@ def join_scope(table, foreign_table, foreign_klass, alias_tracker = nil)
end

# Create base join constraints and add default constraints if available
through_constraints = through_table[through_reflection.join_primary_key].eq(
through_constraint = through_table[through_reflection.join_primary_key].eq(
foreign_table[through_reflection.join_foreign_key],
)

unless (where_clause = through_klass.default_scoped.where_clause).empty?
through_constraints = where_clause.ast.and(through_constraints)
through_constraint = where_clause.ast.and(through_constraint)
end

klass_scope.joins!(
Arel::Nodes::OuterJoin.new(
through_table,
Arel::Nodes::On.new(through_constraints),
Arel::Nodes::On.new(through_constraint),
),
)

constraints << table[source_reflection.join_primary_key].eq(through_table[source_reflection.join_foreign_key])
foreign_constraints << table[source_reflection.join_primary_key].eq(through_table[source_reflection.join_foreign_key])
else
constraints << table[union_reflection.join_primary_key].eq(foreign_table[union_reflection.join_foreign_key])
foreign_constraints << table[union_reflection.join_primary_key].eq(foreign_table[union_reflection.join_foreign_key])
end
end

unless constraints.empty?
klass_scope.where!(constraints.reduce(&:or))
unless foreign_constraints.empty?
foreign_constraint = foreign_constraints.reduce(&:or)

klass_scope.where!(foreign_constraint)
end

unless scope_chain_items.empty?
Expand Down
2 changes: 2 additions & 0 deletions spec/lib/union_of_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,8 @@
)
WHERE
"licenses"."id" = '#{license.id}'
ORDER BY
"users"."created_at" ASC
SQL
end

Expand Down

0 comments on commit af3e1a2

Please sign in to comment.