From cdb7f7c3c7956ae63f2145e8cc940daaa6abc821 Mon Sep 17 00:00:00 2001 From: Jens Schauder Date: Mon, 4 Sep 2023 11:29:21 +0200 Subject: [PATCH 1/9] 1601 Single Query Loading with where clause - Prepare branch --- pom.xml | 2 +- spring-data-jdbc-distribution/pom.xml | 2 +- spring-data-jdbc/pom.xml | 4 ++-- spring-data-r2dbc/pom.xml | 4 ++-- spring-data-relational/pom.xml | 4 ++-- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pom.xml b/pom.xml index cbbdd63bc2..400d09babe 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-relational-parent - 3.2.0-SNAPSHOT + 3.2.0-1601-where-clause-SNAPSHOT pom Spring Data Relational Parent diff --git a/spring-data-jdbc-distribution/pom.xml b/spring-data-jdbc-distribution/pom.xml index 271486f02a..738f08166e 100644 --- a/spring-data-jdbc-distribution/pom.xml +++ b/spring-data-jdbc-distribution/pom.xml @@ -14,7 +14,7 @@ org.springframework.data spring-data-relational-parent - 3.2.0-SNAPSHOT + 3.2.0-1601-where-clause-SNAPSHOT ../pom.xml diff --git a/spring-data-jdbc/pom.xml b/spring-data-jdbc/pom.xml index c2f44d3f96..7d5e900055 100644 --- a/spring-data-jdbc/pom.xml +++ b/spring-data-jdbc/pom.xml @@ -6,7 +6,7 @@ 4.0.0 spring-data-jdbc - 3.2.0-SNAPSHOT + 3.2.0-1601-where-clause-SNAPSHOT Spring Data JDBC Spring Data module for JDBC repositories. @@ -15,7 +15,7 @@ org.springframework.data spring-data-relational-parent - 3.2.0-SNAPSHOT + 3.2.0-1601-where-clause-SNAPSHOT diff --git a/spring-data-r2dbc/pom.xml b/spring-data-r2dbc/pom.xml index a60f8e183a..f46f5fbc29 100644 --- a/spring-data-r2dbc/pom.xml +++ b/spring-data-r2dbc/pom.xml @@ -6,7 +6,7 @@ 4.0.0 spring-data-r2dbc - 3.2.0-SNAPSHOT + 3.2.0-1601-where-clause-SNAPSHOT Spring Data R2DBC Spring Data module for R2DBC @@ -15,7 +15,7 @@ org.springframework.data spring-data-relational-parent - 3.2.0-SNAPSHOT + 3.2.0-1601-where-clause-SNAPSHOT diff --git a/spring-data-relational/pom.xml b/spring-data-relational/pom.xml index 74f350faa8..1741efa6c2 100644 --- a/spring-data-relational/pom.xml +++ b/spring-data-relational/pom.xml @@ -6,7 +6,7 @@ 4.0.0 spring-data-relational - 3.2.0-SNAPSHOT + 3.2.0-1601-where-clause-SNAPSHOT Spring Data Relational Spring Data Relational support @@ -14,7 +14,7 @@ org.springframework.data spring-data-relational-parent - 3.2.0-SNAPSHOT + 3.2.0-1601-where-clause-SNAPSHOT From 95fb0c4b8374cf49bfac52dc0b31caa5c60198fc Mon Sep 17 00:00:00 2001 From: Jens Schauder Date: Wed, 13 Sep 2023 16:01:48 +0200 Subject: [PATCH 2/9] Removed dialect dependency from QueryMapper. The refactoring does not honour the distinction between the default escaper and the like escaper, but since ValueFunctions are only used in the context of LIKE operations I don't see the point. See #1601 --- .../data/jdbc/core/convert/QueryMapper.java | 39 +++++++------- .../data/jdbc/core/convert/SqlGenerator.java | 2 +- .../query/EscapingParameterSource.java | 53 +++++++++++++++++++ .../repository/query/JdbcQueryCreator.java | 2 +- .../repository/query/ParametrizedQuery.java | 12 +++-- .../repository/query/PartTreeJdbcQuery.java | 4 +- .../core/convert/QueryMapperUnitTests.java | 2 +- .../query/PartTreeJdbcQueryUnitTests.java | 33 ++++++------ 8 files changed, 103 insertions(+), 44 deletions(-) create mode 100644 spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java index 65d9d3c231..b4dcd05b64 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.function.Function; import org.springframework.data.domain.Sort; import org.springframework.data.jdbc.core.mapping.JdbcValue; @@ -60,7 +61,6 @@ public class QueryMapper { private final JdbcConverter converter; - private final Dialect dialect; private final MappingContext, RelationalPersistentProperty> mappingContext; /** @@ -68,18 +68,31 @@ public class QueryMapper { * * @param dialect must not be {@literal null}. * @param converter must not be {@literal null}. + * @deprecated use {@link QueryMapper(JdbcConverter)} instead. */ - @SuppressWarnings({ "unchecked", "rawtypes" }) + @Deprecated(since="3.2") public QueryMapper(Dialect dialect, JdbcConverter converter) { Assert.notNull(dialect, "Dialect must not be null"); Assert.notNull(converter, "JdbcConverter must not be null"); this.converter = converter; - this.dialect = dialect; this.mappingContext = (MappingContext) converter.getMappingContext(); } + /** + * Creates a new {@link QueryMapper} with the given {@link JdbcConverter}. + * + * @param converter must not be {@literal null}. + */ + public QueryMapper( JdbcConverter converter) { + + Assert.notNull(converter, "JdbcConverter must not be null"); + + this.converter = converter; + this.mappingContext = converter.getMappingContext(); + } + /** * Map the {@link Sort} object to apply field name mapping using {@link RelationalPersistentEntity the type to read}. * @@ -295,16 +308,13 @@ private Condition mapCondition(CriteriaDefinition criteria, MapSqlParameterSourc mappedValue = convertValue(comparator, settableValue.getValue(), propertyField.getTypeHint()); sqlType = getTypeHint(mappedValue, actualType.getType(), settableValue); - } else if (criteria.getValue() instanceof ValueFunction) { + } else if (criteria.getValue() instanceof ValueFunction valueFunction) { - ValueFunction valueFunction = (ValueFunction) criteria.getValue(); - Object value = valueFunction.apply(getEscaper(comparator)); - - mappedValue = convertValue(comparator, value, propertyField.getTypeHint()); + mappedValue = valueFunction; sqlType = propertyField.getSqlType(); - } else if (propertyField instanceof MetadataBackedField // - && ((MetadataBackedField) propertyField).property != null // + } else if (propertyField instanceof MetadataBackedField metadataBackedField // + && metadataBackedField.property != null // && (criteria.getValue() == null || !criteria.getValue().getClass().isArray())) { RelationalPersistentProperty property = ((MetadataBackedField) propertyField).property; @@ -431,15 +441,6 @@ private Condition mapEmbeddedObjectCondition(CriteriaDefinition criteria, MapSql return Conditions.nest(condition); } - private Escaper getEscaper(Comparator comparator) { - - if (comparator == Comparator.LIKE || comparator == Comparator.NOT_LIKE) { - return dialect.getLikeEscaper(); - } - - return Escaper.DEFAULT; - } - @Nullable private Object convertValue(Comparator comparator, @Nullable Object value, TypeInformation typeHint) { diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlGenerator.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlGenerator.java index 60e69edec7..3058106226 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlGenerator.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlGenerator.java @@ -113,7 +113,7 @@ class SqlGenerator { this.renderContext = new RenderContextFactory(dialect).createRenderContext(); this.sqlRenderer = SqlRenderer.create(renderContext); this.columns = new Columns(entity, mappingContext, converter); - this.queryMapper = new QueryMapper(dialect, converter); + this.queryMapper = new QueryMapper(converter); this.dialect = dialect; } diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java new file mode 100644 index 0000000000..abc85b3d11 --- /dev/null +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java @@ -0,0 +1,53 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.data.jdbc.repository.query; + +import org.springframework.data.relational.core.dialect.Escaper; +import org.springframework.data.relational.core.query.ValueFunction; +import org.springframework.jdbc.core.namedparam.SqlParameterSource; + +/** + * This {@link SqlParameterSource} will apply escaping to it's values. + * + * @author Jens Schauder + * @since 3.2 + */ +public class EscapingParameterSource implements SqlParameterSource { + private final SqlParameterSource parameterSource; + private final Escaper escaper; + + public EscapingParameterSource(SqlParameterSource parameterSource, Escaper escaper) { + + this.parameterSource = parameterSource; + this.escaper = escaper; + } + + @Override + public boolean hasValue(String paramName) { + return parameterSource.hasValue(paramName); + } + + @Override + public Object getValue(String paramName) throws IllegalArgumentException { + + Object value = parameterSource.getValue(paramName); + if (value instanceof ValueFunction) { + return ((ValueFunction) value).apply(escaper); + } + return value; + } +} diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/JdbcQueryCreator.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/JdbcQueryCreator.java index b1cc21571b..7411036ab5 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/JdbcQueryCreator.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/JdbcQueryCreator.java @@ -102,7 +102,7 @@ class JdbcQueryCreator extends RelationalQueryCreator { this.accessor = accessor; this.entityMetadata = entityMetadata; - this.queryMapper = new QueryMapper(dialect, converter); + this.queryMapper = new QueryMapper(converter); this.renderContextFactory = new RenderContextFactory(dialect); this.isSliceQuery = isSliceQuery; this.returnedType = returnedType; diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/ParametrizedQuery.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/ParametrizedQuery.java index b41e2f87f5..ac3c256d27 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/ParametrizedQuery.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/ParametrizedQuery.java @@ -15,12 +15,15 @@ */ package org.springframework.data.jdbc.repository.query; +import org.springframework.data.relational.core.dialect.Dialect; +import org.springframework.data.relational.core.dialect.Escaper; import org.springframework.jdbc.core.namedparam.SqlParameterSource; /** * Value object encapsulating a query containing named parameters and a{@link SqlParameterSource} to bind the parameters. * * @author Mark Paluch + * @author Jens Schauder * @since 2.0 */ class ParametrizedQuery { @@ -38,12 +41,13 @@ String getQuery() { return query; } - SqlParameterSource getParameterSource() { - return parameterSource; - } - @Override public String toString() { return this.query; } + + public SqlParameterSource getParameterSource(Escaper escaper) { + + return new EscapingParameterSource(parameterSource, escaper); + } } diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQuery.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQuery.java index 01876f0d66..ecdbbc5152 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQuery.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQuery.java @@ -126,7 +126,7 @@ public Object execute(Object[] values) { ParametrizedQuery query = createQuery(accessor, processor.getReturnedType()); JdbcQueryExecution execution = getQueryExecution(processor, accessor); - return execution.execute(query.getQuery(), query.getParameterSource()); + return execution.execute(query.getQuery(), query.getParameterSource(dialect.getLikeEscaper())); } private JdbcQueryExecution getQueryExecution(ResultProcessor processor, @@ -164,7 +164,7 @@ private JdbcQueryExecution getQueryExecution(ResultProcessor processor, ParametrizedQuery countQuery = queryCreator.createQuery(Sort.unsorted()); Object count = singleObjectQuery((rs, i) -> rs.getLong(1)).execute(countQuery.getQuery(), - countQuery.getParameterSource()); + countQuery.getParameterSource(dialect.getLikeEscaper())); return converter.getConversionService().convert(count, Long.class); }); diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/QueryMapperUnitTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/QueryMapperUnitTests.java index 185c6d24c5..d2526fc9f2 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/QueryMapperUnitTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/QueryMapperUnitTests.java @@ -50,7 +50,7 @@ public class QueryMapperUnitTests { JdbcMappingContext context = new JdbcMappingContext(); JdbcConverter converter = new BasicJdbcConverter(context, mock(RelationResolver.class)); - QueryMapper mapper = new QueryMapper(PostgresDialect.INSTANCE, converter); + QueryMapper mapper = new QueryMapper(converter); MapSqlParameterSource parameterSource = new MapSqlParameterSource(); @Test // DATAJDBC-318 diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQueryUnitTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQueryUnitTests.java index 74b92b47a8..8c13c68de6 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQueryUnitTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/query/PartTreeJdbcQueryUnitTests.java @@ -36,6 +36,7 @@ import org.springframework.data.jdbc.core.mapping.AggregateReference; import org.springframework.data.jdbc.core.mapping.JdbcMappingContext; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; +import org.springframework.data.relational.core.dialect.Escaper; import org.springframework.data.relational.core.dialect.H2Dialect; import org.springframework.data.relational.core.mapping.Embedded; import org.springframework.data.relational.core.mapping.MappedCollection; @@ -93,7 +94,7 @@ public void createQueryByAggregateReference() throws Exception { softly.assertThat(query.getQuery()) .isEqualTo(BASE_SELECT + " WHERE " + TABLE + ".\"HOBBY_REFERENCE\" = :hobby_reference"); - softly.assertThat(query.getParameterSource().getValue("hobby_reference")).isEqualTo("twentythree"); + softly.assertThat(query.getParameterSource(Escaper.DEFAULT).getValue("hobby_reference")).isEqualTo("twentythree"); }); } @@ -112,8 +113,8 @@ void createQueryWithPessimisticWriteLock() throws Exception { softly.assertThat(query.getQuery().toUpperCase()).endsWith("FOR UPDATE"); - softly.assertThat(query.getParameterSource().getValue("first_name")).isEqualTo(firstname); - softly.assertThat(query.getParameterSource().getValue("last_name")).isEqualTo(lastname); + softly.assertThat(query.getParameterSource(Escaper.DEFAULT).getValue("first_name")).isEqualTo(firstname); + softly.assertThat(query.getParameterSource(Escaper.DEFAULT).getValue("last_name")).isEqualTo(lastname); }); } @@ -133,8 +134,8 @@ void createQueryWithPessimisticReadLock() throws Exception { // this is also for update since h2 dialect does not distinguish between lockmodes softly.assertThat(query.getQuery().toUpperCase()).endsWith("FOR UPDATE"); - softly.assertThat(query.getParameterSource().getValue("first_name")).isEqualTo(firstname); - softly.assertThat(query.getParameterSource().getValue("age")).isEqualTo(age); + softly.assertThat(query.getParameterSource(Escaper.DEFAULT).getValue("first_name")).isEqualTo(firstname); + softly.assertThat(query.getParameterSource(Escaper.DEFAULT).getValue("age")).isEqualTo(age); }); } @@ -165,7 +166,7 @@ public void createQueryForQueryByAggregateReference() throws Exception { softly.assertThat(query.getQuery()) .isEqualTo(BASE_SELECT + " WHERE " + TABLE + ".\"HOBBY_REFERENCE\" = :hobby_reference"); - softly.assertThat(query.getParameterSource().getValue("hobby_reference")).isEqualTo("twentythree"); + softly.assertThat(query.getParameterSource(Escaper.DEFAULT).getValue("hobby_reference")).isEqualTo("twentythree"); }); } @@ -182,7 +183,7 @@ public void createQueryForQueryByAggregateReferenceId() throws Exception { softly.assertThat(query.getQuery()) .isEqualTo(BASE_SELECT + " WHERE " + TABLE + ".\"HOBBY_REFERENCE\" = :hobby_reference"); - softly.assertThat(query.getParameterSource().getValue("hobby_reference")).isEqualTo("twentythree"); + softly.assertThat(query.getParameterSource(Escaper.DEFAULT).getValue("hobby_reference")).isEqualTo("twentythree"); }); } @@ -270,8 +271,8 @@ public void createsQueryToFindAllEntitiesByDateAttributeBetween() throws Excepti softly.assertThat(query.getQuery()) .isEqualTo(BASE_SELECT + " WHERE " + TABLE + ".\"DATE_OF_BIRTH\" BETWEEN :date_of_birth AND :date_of_birth1"); - softly.assertThat(query.getParameterSource().getValue("date_of_birth")).isEqualTo(from); - softly.assertThat(query.getParameterSource().getValue("date_of_birth1")).isEqualTo(to); + softly.assertThat(query.getParameterSource(Escaper.DEFAULT).getValue("date_of_birth")).isEqualTo(from); + softly.assertThat(query.getParameterSource(Escaper.DEFAULT).getValue("date_of_birth1")).isEqualTo(to); }); } @@ -405,7 +406,7 @@ public void appendsLikeOperatorParameterWithPercentSymbolForStartingWithQuery() ParametrizedQuery query = jdbcQuery.createQuery(accessor, returnedType); assertThat(query.getQuery()).isEqualTo(BASE_SELECT + " WHERE " + TABLE + ".\"FIRST_NAME\" LIKE :first_name"); - assertThat(query.getParameterSource().getValue("first_name")).isEqualTo("Jo%"); + assertThat(query.getParameterSource(Escaper.DEFAULT).getValue("first_name")).isEqualTo("Jo%"); } @Test // DATAJDBC-318 @@ -428,7 +429,7 @@ public void prependsLikeOperatorParameterWithPercentSymbolForEndingWithQuery() t ParametrizedQuery query = jdbcQuery.createQuery(accessor, returnedType); assertThat(query.getQuery()).isEqualTo(BASE_SELECT + " WHERE " + TABLE + ".\"FIRST_NAME\" LIKE :first_name"); - assertThat(query.getParameterSource().getValue("first_name")).isEqualTo("%hn"); + assertThat(query.getParameterSource(Escaper.DEFAULT).getValue("first_name")).isEqualTo("%hn"); } @Test // DATAJDBC-318 @@ -451,7 +452,7 @@ public void wrapsLikeOperatorParameterWithPercentSymbolsForContainingQuery() thr ParametrizedQuery query = jdbcQuery.createQuery(accessor, returnedType); assertThat(query.getQuery()).isEqualTo(BASE_SELECT + " WHERE " + TABLE + ".\"FIRST_NAME\" LIKE :first_name"); - assertThat(query.getParameterSource().getValue("first_name")).isEqualTo("%oh%"); + assertThat(query.getParameterSource(Escaper.DEFAULT).getValue("first_name")).isEqualTo("%oh%"); } @Test // DATAJDBC-318 @@ -474,7 +475,7 @@ public void wrapsLikeOperatorParameterWithPercentSymbolsForNotContainingQuery() ParametrizedQuery query = jdbcQuery.createQuery(accessor, returnedType); assertThat(query.getQuery()).isEqualTo(BASE_SELECT + " WHERE " + TABLE + ".\"FIRST_NAME\" NOT LIKE :first_name"); - assertThat(query.getParameterSource().getValue("first_name")).isEqualTo("%oh%"); + assertThat(query.getParameterSource(Escaper.DEFAULT).getValue("first_name")).isEqualTo("%oh%"); } @Test // DATAJDBC-318 @@ -638,8 +639,8 @@ public void createsQueryByEmbeddedObject() throws Exception { .contains(TABLE + ".\"USER_STREET\" = :user_street", // " AND ", // TABLE + ".\"USER_CITY\" = :user_city"); - assertThat(query.getParameterSource().getValue("user_street")).isEqualTo("Hello"); - assertThat(query.getParameterSource().getValue("user_city")).isEqualTo("World"); + assertThat(query.getParameterSource(Escaper.DEFAULT).getValue("user_street")).isEqualTo("Hello"); + assertThat(query.getParameterSource(Escaper.DEFAULT).getValue("user_city")).isEqualTo("World"); } @Test // DATAJDBC-318 @@ -653,7 +654,7 @@ public void createsQueryByEmbeddedProperty() throws Exception { String expectedSql = BASE_SELECT + " WHERE " + TABLE + ".\"USER_STREET\" = :user_street"; assertThat(query.getQuery()).isEqualTo(expectedSql); - assertThat(query.getParameterSource().getValue("user_street")).isEqualTo("Hello"); + assertThat(query.getParameterSource(Escaper.DEFAULT).getValue("user_street")).isEqualTo("Hello"); } @Test // DATAJDBC-534 From 1ddb157e1599c609af8dbc156117d2fe5a2078f7 Mon Sep 17 00:00:00 2001 From: Jens Schauder Date: Mon, 4 Sep 2023 14:15:33 +0200 Subject: [PATCH 3/9] Add support for arbitrary where clauses in Single Query Loading. Closes #1601 --- .../jdbc/core/convert/AggregateReader.java | 45 ++++++++- .../SingleQueryDataAccessStrategy.java | 5 +- ...SingleQueryFallbackDataAccessStrategy.java | 38 ++++++-- ...JdbcAggregateTemplateIntegrationTests.java | 92 +++++++++++++++++-- ...cAggregateTemplateIntegrationTests-db2.sql | 10 +- ...bcAggregateTemplateIntegrationTests-h2.sql | 8 ++ ...AggregateTemplateIntegrationTests-hsql.sql | 7 ++ ...regateTemplateIntegrationTests-mariadb.sql | 7 ++ ...ggregateTemplateIntegrationTests-mssql.sql | 8 ++ ...ggregateTemplateIntegrationTests-mysql.sql | 9 +- ...gregateTemplateIntegrationTests-oracle.sql | 8 ++ ...egateTemplateIntegrationTests-postgres.sql | 11 ++- .../SingleQuerySqlGenerator.java | 15 +-- .../core/sqlgeneration/SqlGenerator.java | 8 ++ 14 files changed, 244 insertions(+), 27 deletions(-) diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java index 77813ea9b3..715ba423f2 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java @@ -21,15 +21,22 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.function.BiFunction; import org.springframework.dao.IncorrectResultSizeDataAccessException; import org.springframework.data.relational.core.dialect.Dialect; import org.springframework.data.relational.core.mapping.AggregatePath; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; +import org.springframework.data.relational.core.query.CriteriaDefinition; +import org.springframework.data.relational.core.query.Query; +import org.springframework.data.relational.core.sql.Condition; +import org.springframework.data.relational.core.sql.Table; import org.springframework.data.relational.core.sqlgeneration.AliasFactory; import org.springframework.data.relational.core.sqlgeneration.SingleQuerySqlGenerator; import org.springframework.data.relational.core.sqlgeneration.SqlGenerator; import org.springframework.data.relational.domain.RowDocument; +import org.springframework.jdbc.core.namedparam.MapSqlParameterSource; import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -89,6 +96,35 @@ public Iterable findAllById(Iterable ids) { return jdbcTemplate.query(sqlGenerator.findAllById(), Map.of("ids", convertedIds), this::extractAll); } + public Iterable findAllBy(Query query) { + + MapSqlParameterSource parameterSource = new MapSqlParameterSource(); + BiFunction condition = createConditionSource(query, parameterSource); + return jdbcTemplate.query(sqlGenerator.findAllByCondition(condition), parameterSource, this::extractAll); + } + + public Optional findOneByQuery(Query query) { + + MapSqlParameterSource parameterSource = new MapSqlParameterSource(); + BiFunction condition = createConditionSource(query, parameterSource); + + return Optional.ofNullable( + jdbcTemplate.query(sqlGenerator.findAllByCondition(condition), parameterSource, this::extractZeroOrOne)); + } + + private BiFunction createConditionSource(Query query, MapSqlParameterSource parameterSource) { + + QueryMapper queryMapper = new QueryMapper(converter); + + BiFunction condition = (table, aggregate) -> { + Optional criteria = query.getCriteria(); + return criteria + .map(criteriaDefinition -> queryMapper.getMappedObject(parameterSource, criteriaDefinition, table, aggregate)) + .orElse(null); + }; + return condition; + } + /** * Extracts a list of aggregates from the given {@link ResultSet} by utilizing the * {@link RowDocumentResultSetExtractor} and the {@link JdbcConverter}. When used as a method reference this conforms @@ -115,7 +151,8 @@ private List extractAll(ResultSet rs) throws SQLException { * to the {@link org.springframework.jdbc.core.ResultSetExtractor} contract. * * @param @param rs the {@link ResultSet} from which to extract the data. Must not be {(}@literal null}. - * @return The single instance when the conversion results in exactly one instance. If the {@literal ResultSet} is empty, null is returned. + * @return The single instance when the conversion results in exactly one instance. If the {@literal ResultSet} is + * empty, null is returned. * @throws SQLException * @throws IncorrectResultSizeDataAccessException when the conversion yields more than one instance. */ @@ -190,9 +227,15 @@ public String findAllById() { return findAllById; } + @Override + public String findAllByCondition(BiFunction conditionSource) { + return delegate.findAllByCondition(conditionSource); + } + @Override public AliasFactory getAliasFactory() { return delegate.getAliasFactory(); } + } } diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java index 3f43d0652e..c0e20c425f 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java @@ -77,12 +77,13 @@ public Iterable findAll(Class domainType, Pageable pageable) { @Override public Optional findOne(Query query, Class domainType) { - return Optional.empty(); + return getReader(domainType).findOneByQuery(query); } @Override public Iterable findAll(Query query, Class domainType) { - throw new UnsupportedOperationException(); + + return getReader(domainType).findAllBy(query); } @Override diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryFallbackDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryFallbackDataAccessStrategy.java index bc93cd09dd..0cb0b04638 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryFallbackDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryFallbackDataAccessStrategy.java @@ -16,9 +16,11 @@ package org.springframework.data.jdbc.core.convert; import java.util.Collections; +import java.util.Optional; import org.springframework.data.mapping.PersistentPropertyPath; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; +import org.springframework.data.relational.core.query.Query; import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; import org.springframework.util.Assert; @@ -85,13 +87,37 @@ public Iterable findAllById(Iterable ids, Class domainType) { return super.findAllById(ids, domainType); } + public Optional findOne(Query query, Class domainType) { + + if (isSingleSelectQuerySupported(domainType) && isSingleSelectQuerySupported(query)) { + return singleSelectDelegate.findOne(query, domainType); + } + + return super.findOne(query, domainType); + } + + @Override + public Iterable findAll(Query query, Class domainType) { + + if (isSingleSelectQuerySupported(domainType) && isSingleSelectQuerySupported(query)) { + return singleSelectDelegate.findAll(query, domainType); + } + + return super.findAll(query, domainType); + } + + private static boolean isSingleSelectQuerySupported(Query query) { + return !query.isSorted() && !query.isLimited(); + } + private boolean isSingleSelectQuerySupported(Class entityType) { - return sqlGeneratorSource.getDialect().supportsSingleQueryLoading()// - && entityQualifiesForSingleSelectQuery(entityType); + return converter.getMappingContext().isSingleQueryLoadingEnabled() + && sqlGeneratorSource.getDialect().supportsSingleQueryLoading()// + && entityQualifiesForSingleQueryLoading(entityType); } - private boolean entityQualifiesForSingleSelectQuery(Class entityType) { + private boolean entityQualifiesForSingleQueryLoading(Class entityType) { boolean referenceFound = false; for (PersistentPropertyPath path : converter.getMappingContext() @@ -113,9 +139,9 @@ private boolean entityQualifiesForSingleSelectQuery(Class entityType) { } // AggregateReferences aren't supported yet - if (property.isAssociation()) { - return false; - } + // if (property.isAssociation()) { + // return false; + // } } return true; diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java index aa790fc854..e76825b29e 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java @@ -23,15 +23,8 @@ import static org.springframework.data.jdbc.testing.TestDatabaseFeatures.Feature.*; import java.time.LocalDateTime; +import java.util.*; import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; import java.util.function.Function; import java.util.stream.IntStream; @@ -42,6 +35,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Import; +import org.springframework.dao.IncorrectResultSizeDataAccessException; import org.springframework.dao.IncorrectUpdateSemanticsDataAccessException; import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.data.annotation.Id; @@ -64,6 +58,9 @@ import org.springframework.data.relational.core.mapping.MappedCollection; import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.relational.core.mapping.Table; +import org.springframework.data.relational.core.query.Criteria; +import org.springframework.data.relational.core.query.CriteriaDefinition; +import org.springframework.data.relational.core.query.Query; import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; import org.springframework.test.context.ActiveProfiles; @@ -223,6 +220,62 @@ void findAllById() { .containsExactlyInAnyOrder(tuple(entity.id, "entity"), tuple(yetAnother.id, "yetAnother")); } + @Test // GH-1601 + void findAllByQuery() { + + template.save(SimpleListParent.of("one", "one_1")); + SimpleListParent two = template.save(SimpleListParent.of("two", "two_1", "two_2")); + template.save(SimpleListParent.of("three", "three_1", "three_2", "three_3")); + + CriteriaDefinition criteria = CriteriaDefinition.from(Criteria.where("id").is(two.id)); + Query query = Query.query(criteria); + Iterable reloadedById = template.findAll(query, SimpleListParent.class); + + assertThat(reloadedById).extracting(e -> e.id, e -> e.content.size()).containsExactly(tuple(two.id, 2)); + } + + @Test // GH-1601 + void findOneByQuery() { + + template.save(SimpleListParent.of("one", "one_1")); + SimpleListParent two = template.save(SimpleListParent.of("two", "two_1", "two_2")); + template.save(SimpleListParent.of("three", "three_1", "three_2", "three_3")); + + CriteriaDefinition criteria = CriteriaDefinition.from(Criteria.where("id").is(two.id)); + Query query = Query.query(criteria); + Optional reloadedById = template.findOne(query, SimpleListParent.class); + + assertThat(reloadedById).get().extracting(e -> e.id, e -> e.content.size()).containsExactly(two.id, 2); + } + + @Test // GH-1601 + void findOneByQueryNothingFound() { + + template.save(SimpleListParent.of("one", "one_1")); + SimpleListParent two = template.save(SimpleListParent.of("two", "two_1", "two_2")); + template.save(SimpleListParent.of("three", "three_1", "three_2", "three_3")); + + CriteriaDefinition criteria = CriteriaDefinition.from(Criteria.where("id").is(4711)); + Query query = Query.query(criteria); + Optional reloadedById = template.findOne(query, SimpleListParent.class); + + assertThat(reloadedById).isEmpty(); + } + + @Test // GH-1601 + void findOneByQueryToManyResults() { + + template.save(SimpleListParent.of("one", "one_1")); + SimpleListParent two = template.save(SimpleListParent.of("two", "two_1", "two_2")); + template.save(SimpleListParent.of("three", "three_1", "three_2", "three_3")); + + CriteriaDefinition criteria = CriteriaDefinition.from(Criteria.where("id").not(two.id)); + Query query = Query.query(criteria); + + assertThatExceptionOfType(IncorrectResultSizeDataAccessException.class) + .isThrownBy(() -> template.findOne(query, SimpleListParent.class)); + } + @Test // DATAJDBC-112 @EnabledOnFeature(SUPPORTS_QUOTED_IDS) void saveAndLoadAnEntityWithReferencedEntityById() { @@ -1266,6 +1319,29 @@ static class ChildNoId { private String content; } + @SuppressWarnings("unused") + static class SimpleListParent { + + @Id private Long id; + String name; + List content = new ArrayList<>(); + + static SimpleListParent of(String name, String... contents) { + + SimpleListParent parent = new SimpleListParent(); + parent.name = name; + + for (String content : contents) { + + ElementNoId element = new ElementNoId(); + element.content = content; + parent.content.add(element); + } + + return parent; + } + } + @Table("LIST_PARENT") @SuppressWarnings("unused") static class ListParent { diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-db2.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-db2.sql index 8ad4fda2dc..f086a03b5c 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-db2.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-db2.sql @@ -6,6 +6,7 @@ DROP TABLE ONE_TO_ONE_PARENT; DROP TABLE ELEMENT_NO_ID; DROP TABLE LIST_PARENT; +DROP TABLE SIMPLE_LIST_PARENT; DROP TABLE BYTE_ARRAY_OWNER; @@ -74,11 +75,18 @@ CREATE TABLE LIST_PARENT "id4" BIGINT GENERATED BY DEFAULT AS IDENTITY ( START WITH 1 ) PRIMARY KEY, NAME VARCHAR(100) ); +CREATE TABLE SIMPLE_LIST_PARENT +( + ID BIGINT GENERATED BY DEFAULT AS IDENTITY ( START WITH 1 ) PRIMARY KEY, + NAME VARCHAR(100) +); CREATE TABLE ELEMENT_NO_ID ( CONTENT VARCHAR(100), LIST_PARENT_KEY BIGINT, - LIST_PARENT BIGINT + SIMPLE_LIST_PARENT_KEY BIGINT, + LIST_PARENT BIGINT, + SIMPLE_LIST_PARENT BIGINT ); ALTER TABLE ELEMENT_NO_ID ADD FOREIGN KEY (LIST_PARENT) diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-h2.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-h2.sql index a0aff08ce8..a6e5eabad7 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-h2.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-h2.sql @@ -32,9 +32,17 @@ CREATE TABLE LIST_PARENT NAME VARCHAR(100) ); +CREATE TABLE SIMPLE_LIST_PARENT +( + ID SERIAL PRIMARY KEY, + NAME VARCHAR(100) +); + CREATE TABLE element_no_id ( content VARCHAR(100), + SIMPLE_LIST_PARENT_key BIGINT, + SIMPLE_LIST_PARENT INTEGER, LIST_PARENT_key BIGINT, LIST_PARENT INTEGER ); diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-hsql.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-hsql.sql index 4dd1294ab2..dc73899207 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-hsql.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-hsql.sql @@ -26,6 +26,11 @@ CREATE TABLE Child_No_Id content VARCHAR(30) ); +CREATE TABLE SIMPLE_LIST_PARENT +( + ID BIGINT GENERATED BY DEFAULT AS IDENTITY ( START WITH 1 ) PRIMARY KEY, + NAME VARCHAR(100) +); CREATE TABLE LIST_PARENT ( "id4" BIGINT GENERATED BY DEFAULT AS IDENTITY ( START WITH 1 ) PRIMARY KEY, @@ -34,6 +39,8 @@ CREATE TABLE LIST_PARENT CREATE TABLE ELEMENT_NO_ID ( CONTENT VARCHAR(100), + SIMPLE_LIST_PARENT_KEY BIGINT, + SIMPLE_LIST_PARENT BIGINT, LIST_PARENT_KEY BIGINT, LIST_PARENT BIGINT ); diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mariadb.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mariadb.sql index 4dd82b9003..4258e7b438 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mariadb.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mariadb.sql @@ -31,9 +31,16 @@ CREATE TABLE LIST_PARENT `id4` BIGINT AUTO_INCREMENT PRIMARY KEY, NAME VARCHAR(100) ); +CREATE TABLE SIMPLE_LIST_PARENT +( + ID BIGINT AUTO_INCREMENT PRIMARY KEY, + NAME VARCHAR(100) +); CREATE TABLE element_no_id ( CONTENT VARCHAR(100), + SIMPLE_LIST_PARENT_key BIGINT, + SIMPLE_LIST_PARENT BIGINT, LIST_PARENT_key BIGINT, LIST_PARENT BIGINT ); diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mssql.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mssql.sql index 880528cdbf..e9a378f49b 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mssql.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mssql.sql @@ -30,14 +30,22 @@ CREATE TABLE Child_No_Id DROP TABLE IF EXISTS element_no_id; DROP TABLE IF EXISTS LIST_PARENT; +DROP TABLE IF EXISTS SIMPLE_LIST_PARENT; CREATE TABLE LIST_PARENT ( [id4] BIGINT IDENTITY PRIMARY KEY, NAME VARCHAR(100) ); +CREATE TABLE SIMPLE_LIST_PARENT +( + ID BIGINT IDENTITY PRIMARY KEY, + NAME VARCHAR(100) +); CREATE TABLE element_no_id ( CONTENT VARCHAR(100), + SIMPLE_LIST_PARENT_key BIGINT, + SIMPLE_LIST_PARENT BIGINT, LIST_PARENT_key BIGINT, LIST_PARENT BIGINT ); diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mysql.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mysql.sql index 6808c8a912..40e32f1692 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mysql.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-mysql.sql @@ -26,6 +26,11 @@ CREATE TABLE Child_No_Id `content` VARCHAR(30) ); +CREATE TABLE SIMPLE_LIST_PARENT +( + ID BIGINT AUTO_INCREMENT PRIMARY KEY, + NAME VARCHAR(100) +); CREATE TABLE LIST_PARENT ( `id4` BIGINT AUTO_INCREMENT PRIMARY KEY, @@ -35,7 +40,9 @@ CREATE TABLE element_no_id ( CONTENT VARCHAR(100), LIST_PARENT_key BIGINT, - LIST_PARENT BIGINT + SIMPLE_LIST_PARENT_key BIGINT, + LIST_PARENT BIGINT, + SIMPLE_LIST_PARENT BIGINT ); CREATE TABLE BYTE_ARRAY_OWNER diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-oracle.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-oracle.sql index 084e5db460..5a5c5baf40 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-oracle.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-oracle.sql @@ -4,6 +4,7 @@ DROP TABLE CHILD_NO_ID CASCADE CONSTRAINTS PURGE; DROP TABLE ONE_TO_ONE_PARENT CASCADE CONSTRAINTS PURGE; DROP TABLE ELEMENT_NO_ID CASCADE CONSTRAINTS PURGE; DROP TABLE LIST_PARENT CASCADE CONSTRAINTS PURGE; +DROP TABLE SIMPLE_LIST_PARENT CASCADE CONSTRAINTS PURGE; DROP TABLE BYTE_ARRAY_OWNER CASCADE CONSTRAINTS PURGE; DROP TABLE CHAIN0 CASCADE CONSTRAINTS PURGE; DROP TABLE CHAIN1 CASCADE CONSTRAINTS PURGE; @@ -64,9 +65,16 @@ CREATE TABLE LIST_PARENT "id4" NUMBER GENERATED by default on null as IDENTITY PRIMARY KEY, NAME VARCHAR(100) ); +CREATE TABLE SIMPLE_LIST_PARENT +( + ID NUMBER GENERATED by default on null as IDENTITY PRIMARY KEY, + NAME VARCHAR(100) +); CREATE TABLE element_no_id ( CONTENT VARCHAR(100), + SIMPLE_LIST_PARENT_key NUMBER, + SIMPLE_LIST_PARENT NUMBER, LIST_PARENT_key NUMBER, LIST_PARENT NUMBER ); diff --git a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-postgres.sql b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-postgres.sql index 0c77c88139..d43b5750b1 100644 --- a/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-postgres.sql +++ b/spring-data-jdbc/src/test/resources/org.springframework.data.jdbc.core/JdbcAggregateTemplateIntegrationTests-postgres.sql @@ -4,6 +4,7 @@ DROP TABLE ONE_TO_ONE_PARENT; DROP TABLE Child_No_Id; DROP TABLE element_no_id; DROP TABLE "LIST_PARENT"; +DROP TABLE SIMPLE_LIST_PARENT; DROP TABLE "ARRAY_OWNER"; DROP TABLE DOUBLE_LIST_OWNER; DROP TABLE FLOAT_LIST_OWNER; @@ -68,11 +69,19 @@ CREATE TABLE "LIST_PARENT" NAME VARCHAR(100) ); +CREATE TABLE SIMPLE_LIST_PARENT +( + id SERIAL PRIMARY KEY, + NAME VARCHAR(100) +); + CREATE TABLE element_no_id ( content VARCHAR(100), LIST_PARENT_key BIGINT, - "LIST_PARENT" INTEGER + SIMPLE_LIST_PARENT_key BIGINT, + "LIST_PARENT" INTEGER, + SIMPLE_LIST_PARENT INTEGER ); CREATE TABLE "ARRAY_OWNER" diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java index 5bb11e4b81..505a027dca 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java @@ -19,6 +19,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.function.BiFunction; import org.jetbrains.annotations.NotNull; import org.springframework.data.mapping.PersistentProperty; @@ -81,6 +82,12 @@ public String findAllById() { return createSelect(condition); } + @Override + public String findAllByCondition(BiFunction conditionSource) { + Condition condition = conditionSource.apply(table, aggregate); + return createSelect(condition); + } + /** * @return The {@link AggregatePath} to the id property of the aggregate root. */ @@ -88,13 +95,7 @@ private AggregatePath getRootIdPath() { return context.getAggregatePath(aggregate).append(aggregate.getRequiredIdProperty()); } - /** - * Creates a SQL suitable of loading all the data required for constructing complete aggregates. - * - * @param condition a constraint for limiting the aggregates to be loaded. - * @return a {@literal String} containing the generated SQL statement - */ - private String createSelect(Condition condition) { + String createSelect(Condition condition) { AggregatePath rootPath = context.getAggregatePath(aggregate); QueryMeta queryMeta = createInlineQuery(rootPath, condition); diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java index 78049657e0..80eb9a1a87 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java @@ -15,6 +15,12 @@ */ package org.springframework.data.relational.core.sqlgeneration; +import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; +import org.springframework.data.relational.core.sql.Condition; +import org.springframework.data.relational.core.sql.Table; + +import java.util.function.BiFunction; + /** * Generates SQL statements for loading aggregates. * @@ -28,5 +34,7 @@ public interface SqlGenerator { String findAllById(); + String findAllByCondition(BiFunction conditionSource); + AliasFactory getAliasFactory(); } From 1e63e31c120c872ff06fd630d479eebefc49539d Mon Sep 17 00:00:00 2001 From: Jens Schauder Date: Mon, 25 Sep 2023 10:09:37 +0200 Subject: [PATCH 4/9] change return value of AggregateReader to List --- .../data/jdbc/core/convert/AggregateReader.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java index 715ba423f2..da1373adbd 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java @@ -86,7 +86,7 @@ public T findById(Object id) { return jdbcTemplate.query(sqlGenerator.findById(), Map.of("id", id), this::extractZeroOrOne); } - public Iterable findAllById(Iterable ids) { + public List findAllById(Iterable ids) { List convertedIds = new ArrayList<>(); for (Object id : ids) { @@ -96,7 +96,7 @@ public Iterable findAllById(Iterable ids) { return jdbcTemplate.query(sqlGenerator.findAllById(), Map.of("ids", convertedIds), this::extractAll); } - public Iterable findAllBy(Query query) { + public List findAllBy(Query query) { MapSqlParameterSource parameterSource = new MapSqlParameterSource(); BiFunction condition = createConditionSource(query, parameterSource); From d43740e233492fb0b8dd19f7b23949cf86d93231 Mon Sep 17 00:00:00 2001 From: Jens Schauder Date: Mon, 25 Sep 2023 10:26:12 +0200 Subject: [PATCH 5/9] introduced proper implementation of ValueFunction --- .../query/ParameterMetadataProvider.java | 7 ++-- .../repository/query/SimpleValueFunction.java | 34 +++++++++++++++++++ 2 files changed, 37 insertions(+), 4 deletions(-) create mode 100644 spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/SimpleValueFunction.java diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java b/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java index 818fa8578f..00657d3ca8 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java @@ -20,7 +20,6 @@ import java.util.List; import org.springframework.data.relational.core.dialect.Escaper; -import org.springframework.data.relational.core.query.ValueFunction; import org.springframework.data.repository.query.Parameter; import org.springframework.data.repository.query.Parameters; import org.springframework.data.repository.query.parser.Part; @@ -139,12 +138,12 @@ protected Object prepareParameterValue(@Nullable Object value, Class valueTyp switch (partType) { case STARTING_WITH: - return (ValueFunction) escaper -> escaper.escape(value.toString()) + "%"; + return SimpleValueFunction.of(value, s -> s + "%"); case ENDING_WITH: - return (ValueFunction) escaper -> "%" + escaper.escape(value.toString()); + return SimpleValueFunction.of(value, s -> "%" + s); case CONTAINING: case NOT_CONTAINING: - return (ValueFunction) escaper -> "%" + escaper.escape(value.toString()) + "%"; + return SimpleValueFunction.of(value, s -> "%" + s + "%"); default: return value; } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/SimpleValueFunction.java b/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/SimpleValueFunction.java new file mode 100644 index 0000000000..9e32867eda --- /dev/null +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/SimpleValueFunction.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.data.relational.repository.query; + +import org.springframework.data.relational.core.dialect.Escaper; +import org.springframework.data.relational.core.query.ValueFunction; + +import java.util.function.Function; + +record SimpleValueFunction(Object value, Function modifier) implements ValueFunction { + + static SimpleValueFunction of(Object value, Function modifier) { + return new SimpleValueFunction(value, modifier); + } + + @Override + public String apply(Escaper escaper) { + return modifier.apply(escaper.escape(value.toString())); + } +} From 8fe6fdad24ad46bb6800e763d86514060b5beab6 Mon Sep 17 00:00:00 2001 From: Jens Schauder Date: Mon, 25 Sep 2023 13:12:00 +0200 Subject: [PATCH 6/9] support conversion for ValueFunction --- .../data/jdbc/core/convert/QueryMapper.java | 2 +- .../query/EscapingParameterSource.java | 4 ++-- .../data/r2dbc/query/QueryMapper.java | 6 ++---- .../data/r2dbc/query/UpdateMapper.java | 10 +++------- .../relational/core/query/ValueFunction.java | 12 ++++++++++++ ...unction.java => ModifyingValueFunction.java} | 17 ++++++++++++++--- .../query/ParameterMetadataProvider.java | 6 +++--- 7 files changed, 37 insertions(+), 20 deletions(-) rename spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/{SimpleValueFunction.java => ModifyingValueFunction.java} (62%) diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java index b4dcd05b64..547eac6716 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java @@ -310,7 +310,7 @@ private Condition mapCondition(CriteriaDefinition criteria, MapSqlParameterSourc sqlType = getTypeHint(mappedValue, actualType.getType(), settableValue); } else if (criteria.getValue() instanceof ValueFunction valueFunction) { - mappedValue = valueFunction; + mappedValue = valueFunction.transform(v -> convertValue(comparator, v, propertyField.getTypeHint())); sqlType = propertyField.getSqlType(); } else if (propertyField instanceof MetadataBackedField metadataBackedField // diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java index abc85b3d11..8f0fa6e818 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java @@ -45,8 +45,8 @@ public boolean hasValue(String paramName) { public Object getValue(String paramName) throws IllegalArgumentException { Object value = parameterSource.getValue(paramName); - if (value instanceof ValueFunction) { - return ((ValueFunction) value).apply(escaper); + if (value instanceof ValueFunction valueFunction) { + return valueFunction.apply(escaper); } return value; } diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java index 3a6a2936ed..08adc38d74 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java @@ -373,12 +373,10 @@ private Condition mapCondition(CriteriaDefinition criteria, MutableBindings bind mappedValue = convertValue(comparator, parameter.getValue(), propertyField.getTypeHint()); typeHint = getTypeHint(mappedValue, actualType.getType(), parameter); - } else if (criteria.getValue() instanceof ValueFunction) { + } else if (criteria.getValue() instanceof ValueFunction valueFunction) { - ValueFunction valueFunction = (ValueFunction) criteria.getValue(); - Object value = valueFunction.apply(getEscaper(comparator)); + mappedValue = valueFunction.transform(v -> convertValue(comparator, v, propertyField.getTypeHint())).apply(getEscaper(comparator)); - mappedValue = convertValue(comparator, value, propertyField.getTypeHint()); typeHint = actualType.getType(); } else { diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java index 770010dc31..b77d95bab9 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java @@ -111,18 +111,14 @@ private Assignment getAssignment(SqlIdentifier columnName, Object value, Mutable Object mappedValue; Class typeHint; - if (value instanceof Parameter) { - - Parameter parameter = (Parameter) value; + if (value instanceof Parameter parameter) { mappedValue = convertValue(parameter.getValue(), propertyField.getTypeHint()); typeHint = getTypeHint(mappedValue, actualType.getType(), parameter); - } else if (value instanceof ValueFunction) { - - ValueFunction valueFunction = (ValueFunction) value; + } else if (value instanceof ValueFunction valueFunction) { - mappedValue = convertValue(valueFunction.apply(Escaper.DEFAULT), propertyField.getTypeHint()); + mappedValue = valueFunction.transform(v -> convertValue(v, propertyField.getTypeHint())).apply(Escaper.DEFAULT); if (mappedValue == null) { return Assignments.value(column, SQL.nullLiteral()); diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/ValueFunction.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/ValueFunction.java index cd6908174e..780fdf0d9d 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/ValueFunction.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/ValueFunction.java @@ -56,4 +56,16 @@ default Supplier toSupplier(Escaper escaper) { return () -> apply(escaper); } + + /** + * Transforms the inner value of the ValueFunction using the profided transformation. + * + * The default implementation just return the current {@literal ValueFunction}. + * This is not a valid implementation and serves just to maintain backward compatibility. + * + * @param transformation to be applied to the underlying value. + * @return a new {@literal ValueFunction}. + * @since 3.2 + */ + default ValueFunction transform(Function transformation) {return this;}; } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/SimpleValueFunction.java b/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ModifyingValueFunction.java similarity index 62% rename from spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/SimpleValueFunction.java rename to spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ModifyingValueFunction.java index 9e32867eda..2be3e8c371 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/SimpleValueFunction.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ModifyingValueFunction.java @@ -21,14 +21,25 @@ import java.util.function.Function; -record SimpleValueFunction(Object value, Function modifier) implements ValueFunction { +/** + * Value function that has an underlying value and a modifier that gets applied after the escaper. + * + * @author Jens Schauder + * @since 3.2 + */ +record ModifyingValueFunction(Object value, Function modifier) implements ValueFunction { - static SimpleValueFunction of(Object value, Function modifier) { - return new SimpleValueFunction(value, modifier); + static ModifyingValueFunction of(Object value, Function modifier) { + return new ModifyingValueFunction(value, modifier); } @Override public String apply(Escaper escaper) { return modifier.apply(escaper.escape(value.toString())); } + + @Override + public ValueFunction transform(Function transformation) { + return new ModifyingValueFunction(transformation.apply(value), modifier); + } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java b/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java index 00657d3ca8..1261899b49 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java @@ -138,12 +138,12 @@ protected Object prepareParameterValue(@Nullable Object value, Class valueTyp switch (partType) { case STARTING_WITH: - return SimpleValueFunction.of(value, s -> s + "%"); + return ModifyingValueFunction.of(value, s -> s + "%"); case ENDING_WITH: - return SimpleValueFunction.of(value, s -> "%" + s); + return ModifyingValueFunction.of(value, s -> "%" + s); case CONTAINING: case NOT_CONTAINING: - return SimpleValueFunction.of(value, s -> "%" + s + "%"); + return ModifyingValueFunction.of(value, s -> "%" + s + "%"); default: return value; } From 1ed34e8931c3d0957556683140aad9a3ca644319 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Tue, 26 Sep 2023 10:22:29 +0200 Subject: [PATCH 7/9] Polishing. Simplify ValueFunction mapping. Remove invariants of findBy SQL generation in favor of the Condition-based variant. Reduce visibility. --- .../jdbc/core/convert/AggregateReader.java | 95 ++++++++----------- .../data/jdbc/core/convert/QueryMapper.java | 7 +- .../SingleQueryDataAccessStrategy.java | 5 +- ...SingleQueryFallbackDataAccessStrategy.java | 6 +- .../query/EscapingParameterSource.java | 7 +- .../repository/query/ParametrizedQuery.java | 13 ++- .../data/r2dbc/query/QueryMapper.java | 14 +-- .../data/r2dbc/query/UpdateMapper.java | 2 +- .../relational/core/query/ValueFunction.java | 16 ++-- .../SingleQuerySqlGenerator.java | 44 +-------- .../core/sqlgeneration/SqlGenerator.java | 14 +-- .../query/ModifyingValueFunction.java | 45 --------- .../query/ParameterMetadataProvider.java | 19 ++-- .../DerivedSqlIdentifierUnitTests.java | 14 ++- .../SingleQuerySqlGeneratorUnitTests.java | 24 +++-- 15 files changed, 109 insertions(+), 216 deletions(-) delete mode 100644 spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ModifyingValueFunction.java diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java index da1373adbd..b51b457359 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java @@ -18,16 +18,16 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; +import java.util.Collection; import java.util.Iterator; import java.util.List; -import java.util.Map; import java.util.Optional; -import java.util.function.BiFunction; import org.springframework.dao.IncorrectResultSizeDataAccessException; import org.springframework.data.relational.core.dialect.Dialect; import org.springframework.data.relational.core.mapping.AggregatePath; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; +import org.springframework.data.relational.core.query.Criteria; import org.springframework.data.relational.core.query.CriteriaDefinition; import org.springframework.data.relational.core.query.Query; import org.springframework.data.relational.core.sql.Condition; @@ -36,6 +36,7 @@ import org.springframework.data.relational.core.sqlgeneration.SingleQuerySqlGenerator; import org.springframework.data.relational.core.sqlgeneration.SqlGenerator; import org.springframework.data.relational.domain.RowDocument; +import org.springframework.data.util.Streamable; import org.springframework.jdbc.core.namedparam.MapSqlParameterSource; import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; import org.springframework.lang.Nullable; @@ -43,7 +44,7 @@ /** * Reads complete Aggregates from the database, by generating appropriate SQL using a {@link SingleQuerySqlGenerator} - * through {@link org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate}. Results are converterd into an + * through {@link org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate}. Results are converted into an * intermediate {@link RowDocumentResultSetExtractor RowDocument} and mapped via * {@link org.springframework.data.relational.core.conversion.RelationalConverter#read(Class, RowDocument)}. * @@ -55,7 +56,8 @@ class AggregateReader { private final RelationalPersistentEntity aggregate; - private final org.springframework.data.relational.core.sqlgeneration.SqlGenerator sqlGenerator; + private final Table table; + private final SqlGenerator sqlGenerator; private final JdbcConverter converter; private final NamedParameterJdbcOperations jdbcTemplate; private final RowDocumentResultSetExtractor extractor; @@ -66,6 +68,7 @@ class AggregateReader { this.converter = converter; this.aggregate = aggregate; this.jdbcTemplate = jdbcTemplate; + this.table = Table.create(aggregate.getQualifiedTableName()); this.sqlGenerator = new CachingSqlGenerator( new SingleQuerySqlGenerator(converter.getMappingContext(), aliasFactory, dialect, aggregate)); @@ -74,62 +77,58 @@ class AggregateReader { createPathToColumnMapping(aliasFactory)); } - public List findAll() { - return jdbcTemplate.query(sqlGenerator.findAll(), this::extractAll); - } - @Nullable public T findById(Object id) { - id = converter.writeValue(id, aggregate.getRequiredIdProperty().getTypeInformation()); + Query query = Query.query(Criteria.where(aggregate.getRequiredIdProperty().getName()).is(id)).limit(1); - return jdbcTemplate.query(sqlGenerator.findById(), Map.of("id", id), this::extractZeroOrOne); + return findOne(query); } - public List findAllById(Iterable ids) { + @Nullable + public T findOne(Query query) { - List convertedIds = new ArrayList<>(); - for (Object id : ids) { - convertedIds.add(converter.writeValue(id, aggregate.getRequiredIdProperty().getTypeInformation())); - } + MapSqlParameterSource parameterSource = new MapSqlParameterSource(); + Condition condition = createCondition(query, parameterSource); - return jdbcTemplate.query(sqlGenerator.findAllById(), Map.of("ids", convertedIds), this::extractAll); + return jdbcTemplate.query(sqlGenerator.findAll(condition), parameterSource, this::extractZeroOrOne); } - public List findAllBy(Query query) { + public List findAll() { + return jdbcTemplate.query(sqlGenerator.findAll(), this::extractAll); + } - MapSqlParameterSource parameterSource = new MapSqlParameterSource(); - BiFunction condition = createConditionSource(query, parameterSource); - return jdbcTemplate.query(sqlGenerator.findAllByCondition(condition), parameterSource, this::extractAll); + public List findAllById(Iterable ids) { + + Collection identifiers = ids instanceof Collection idl ? idl : Streamable.of(ids).toList(); + Query query = Query.query(Criteria.where(aggregate.getRequiredIdProperty().getName()).in(identifiers)).limit(1); + + return findAll(query); } - public Optional findOneByQuery(Query query) { - - MapSqlParameterSource parameterSource = new MapSqlParameterSource(); - BiFunction condition = createConditionSource(query, parameterSource); + public List findAll(Query query) { - return Optional.ofNullable( - jdbcTemplate.query(sqlGenerator.findAllByCondition(condition), parameterSource, this::extractZeroOrOne)); + MapSqlParameterSource parameterSource = new MapSqlParameterSource(); + Condition condition = createCondition(query, parameterSource); + return jdbcTemplate.query(sqlGenerator.findAll(condition), parameterSource, this::extractAll); } - private BiFunction createConditionSource(Query query, MapSqlParameterSource parameterSource) { + @Nullable + private Condition createCondition(Query query, MapSqlParameterSource parameterSource) { QueryMapper queryMapper = new QueryMapper(converter); - BiFunction condition = (table, aggregate) -> { - Optional criteria = query.getCriteria(); - return criteria - .map(criteriaDefinition -> queryMapper.getMappedObject(parameterSource, criteriaDefinition, table, aggregate)) - .orElse(null); - }; - return condition; + Optional criteria = query.getCriteria(); + return criteria + .map(criteriaDefinition -> queryMapper.getMappedObject(parameterSource, criteriaDefinition, table, aggregate)) + .orElse(null); } /** * Extracts a list of aggregates from the given {@link ResultSet} by utilizing the * {@link RowDocumentResultSetExtractor} and the {@link JdbcConverter}. When used as a method reference this conforms * to the {@link org.springframework.jdbc.core.ResultSetExtractor} contract. - * + * * @param rs the {@link ResultSet} from which to extract the data. Must not be {(}@literal null}. * @return a {@code List} of aggregates, fully converted. * @throws SQLException @@ -195,21 +194,15 @@ public String keyColumn(AggregatePath path) { * @author Jens Schauder * @since 3.2 */ - static class CachingSqlGenerator implements org.springframework.data.relational.core.sqlgeneration.SqlGenerator { - - private final org.springframework.data.relational.core.sqlgeneration.SqlGenerator delegate; + static class CachingSqlGenerator implements SqlGenerator { + private final SqlGenerator delegate; private final String findAll; - private final String findById; - private final String findAllById; public CachingSqlGenerator(SqlGenerator delegate) { this.delegate = delegate; - - findAll = delegate.findAll(); - findById = delegate.findById(); - findAllById = delegate.findAllById(); + this.findAll = delegate.findAll(); } @Override @@ -218,18 +211,8 @@ public String findAll() { } @Override - public String findById() { - return findById; - } - - @Override - public String findAllById() { - return findAllById; - } - - @Override - public String findAllByCondition(BiFunction conditionSource) { - return delegate.findAllByCondition(conditionSource); + public String findAll(@Nullable Condition condition) { + return delegate.findAll(condition); } @Override diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java index 547eac6716..6695647197 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/QueryMapper.java @@ -22,8 +22,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; -import java.util.function.Function; import org.springframework.data.domain.Sort; import org.springframework.data.jdbc.core.mapping.JdbcValue; @@ -35,7 +33,6 @@ import org.springframework.data.mapping.context.InvalidPersistentPropertyPath; import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.relational.core.dialect.Dialect; -import org.springframework.data.relational.core.dialect.Escaper; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; import org.springframework.data.relational.core.query.CriteriaDefinition; @@ -77,7 +74,7 @@ public QueryMapper(Dialect dialect, JdbcConverter converter) { Assert.notNull(converter, "JdbcConverter must not be null"); this.converter = converter; - this.mappingContext = (MappingContext) converter.getMappingContext(); + this.mappingContext = converter.getMappingContext(); } /** @@ -310,7 +307,7 @@ private Condition mapCondition(CriteriaDefinition criteria, MapSqlParameterSourc sqlType = getTypeHint(mappedValue, actualType.getType(), settableValue); } else if (criteria.getValue() instanceof ValueFunction valueFunction) { - mappedValue = valueFunction.transform(v -> convertValue(comparator, v, propertyField.getTypeHint())); + mappedValue = valueFunction.map(v -> convertValue(comparator, v, propertyField.getTypeHint())); sqlType = propertyField.getSqlType(); } else if (propertyField instanceof MetadataBackedField metadataBackedField // diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java index c0e20c425f..a609619c2d 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java @@ -77,13 +77,12 @@ public Iterable findAll(Class domainType, Pageable pageable) { @Override public Optional findOne(Query query, Class domainType) { - return getReader(domainType).findOneByQuery(query); + return Optional.ofNullable(getReader(domainType).findOne(query)); } @Override public Iterable findAll(Query query, Class domainType) { - - return getReader(domainType).findAllBy(query); + return getReader(domainType).findAll(query); } @Override diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryFallbackDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryFallbackDataAccessStrategy.java index 0cb0b04638..9628588f7a 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryFallbackDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryFallbackDataAccessStrategy.java @@ -87,6 +87,7 @@ public Iterable findAllById(Iterable ids, Class domainType) { return super.findAllById(ids, domainType); } + @Override public Optional findOne(Query query, Class domainType) { if (isSingleSelectQuerySupported(domainType) && isSingleSelectQuerySupported(query)) { @@ -137,11 +138,6 @@ private boolean entityQualifiesForSingleQueryLoading(Class entityType) { referenceFound = true; } - - // AggregateReferences aren't supported yet - // if (property.isAssociation()) { - // return false; - // } } return true; diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java index 8f0fa6e818..16bdad90e6 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/EscapingParameterSource.java @@ -21,12 +21,13 @@ import org.springframework.jdbc.core.namedparam.SqlParameterSource; /** - * This {@link SqlParameterSource} will apply escaping to it's values. - * + * This {@link SqlParameterSource} will apply escaping to its values. + * * @author Jens Schauder * @since 3.2 */ -public class EscapingParameterSource implements SqlParameterSource { +class EscapingParameterSource implements SqlParameterSource { + private final SqlParameterSource parameterSource; private final Escaper escaper; diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/ParametrizedQuery.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/ParametrizedQuery.java index ac3c256d27..b2f5c6ac93 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/ParametrizedQuery.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/ParametrizedQuery.java @@ -15,12 +15,12 @@ */ package org.springframework.data.jdbc.repository.query; -import org.springframework.data.relational.core.dialect.Dialect; import org.springframework.data.relational.core.dialect.Escaper; import org.springframework.jdbc.core.namedparam.SqlParameterSource; /** - * Value object encapsulating a query containing named parameters and a{@link SqlParameterSource} to bind the parameters. + * Value object encapsulating a query containing named parameters and a{@link SqlParameterSource} to bind the + * parameters. * * @author Mark Paluch * @author Jens Schauder @@ -41,13 +41,12 @@ String getQuery() { return query; } + SqlParameterSource getParameterSource(Escaper escaper) { + return new EscapingParameterSource(parameterSource, escaper); + } + @Override public String toString() { return this.query; } - - public SqlParameterSource getParameterSource(Escaper escaper) { - - return new EscapingParameterSource(parameterSource, escaper); - } } diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java index 08adc38d74..0f233cdd64 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java @@ -176,9 +176,8 @@ public Expression getMappedObject(Expression expression, @Nullable RelationalPer return expression; } - if (expression instanceof Column) { + if (expression instanceof Column column) { - Column column = (Column) expression; Field field = createPropertyField(entity, column.getName()); TableLike table = column.getTable(); @@ -186,9 +185,7 @@ public Expression getMappedObject(Expression expression, @Nullable RelationalPer return column instanceof Aliased ? columnFromTable.as(((Aliased) column).getAlias()) : columnFromTable; } - if (expression instanceof SimpleFunction) { - - SimpleFunction function = (SimpleFunction) expression; + if (expression instanceof SimpleFunction function) { List arguments = function.getExpressions(); List mappedArguments = new ArrayList<>(arguments.size()); @@ -367,15 +364,14 @@ private Condition mapCondition(CriteriaDefinition criteria, MutableBindings bind Class typeHint; Comparator comparator = criteria.getComparator(); - if (criteria.getValue() instanceof Parameter) { - - Parameter parameter = (Parameter) criteria.getValue(); + if (criteria.getValue()instanceof Parameter parameter) { mappedValue = convertValue(comparator, parameter.getValue(), propertyField.getTypeHint()); typeHint = getTypeHint(mappedValue, actualType.getType(), parameter); } else if (criteria.getValue() instanceof ValueFunction valueFunction) { - mappedValue = valueFunction.transform(v -> convertValue(comparator, v, propertyField.getTypeHint())).apply(getEscaper(comparator)); + mappedValue = valueFunction.map(v -> convertValue(comparator, v, propertyField.getTypeHint())) + .apply(getEscaper(comparator)); typeHint = actualType.getType(); } else { diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java index b77d95bab9..372ed39048 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java @@ -118,7 +118,7 @@ private Assignment getAssignment(SqlIdentifier columnName, Object value, Mutable } else if (value instanceof ValueFunction valueFunction) { - mappedValue = valueFunction.transform(v -> convertValue(v, propertyField.getTypeHint())).apply(Escaper.DEFAULT); + mappedValue = valueFunction.map(v -> convertValue(v, propertyField.getTypeHint())).apply(Escaper.DEFAULT); if (mappedValue == null) { return Assignments.value(column, SQL.nullLiteral()); diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/ValueFunction.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/ValueFunction.java index 780fdf0d9d..8951ac2a81 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/ValueFunction.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/query/ValueFunction.java @@ -58,14 +58,18 @@ default Supplier toSupplier(Escaper escaper) { } /** - * Transforms the inner value of the ValueFunction using the profided transformation. + * Return a new ValueFunction applying the given mapping {@link Function}. The mapping function is applied after + * applying {@link Escaper}. * - * The default implementation just return the current {@literal ValueFunction}. - * This is not a valid implementation and serves just to maintain backward compatibility. - * - * @param transformation to be applied to the underlying value. + * @param mapper the mapping function to apply to the value. + * @param the type of the value returned from the mapping function. * @return a new {@literal ValueFunction}. * @since 3.2 */ - default ValueFunction transform(Function transformation) {return this;}; + default ValueFunction map(Function mapper) { + + Assert.notNull(mapper, "Mapping function must not be null"); + + return escaper -> mapper.apply(this.apply(escaper)); + } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java index 505a027dca..ff0a61f771 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java @@ -19,9 +19,9 @@ import java.util.Collection; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; import org.springframework.data.mapping.PersistentProperty; import org.springframework.data.mapping.PersistentPropertyPath; import org.springframework.data.mapping.PersistentPropertyPaths; @@ -46,7 +46,6 @@ public class SingleQuerySqlGenerator implements SqlGenerator { private final Dialect dialect; private final AliasFactory aliases; private final RelationalPersistentEntity aggregate; - private final Table table; public SingleQuerySqlGenerator(RelationalMappingContext context, AliasFactory aliasFactory, Dialect dialect, RelationalPersistentEntity aggregate) { @@ -55,47 +54,14 @@ public SingleQuerySqlGenerator(RelationalMappingContext context, AliasFactory al this.aliases = aliasFactory; this.dialect = dialect; this.aggregate = aggregate; - - this.table = Table.create(aggregate.getQualifiedTableName()); - } - - @Override - public String findAll() { - return createSelect(null); - } - - @Override - public String findById() { - - AggregatePath path = getRootIdPath(); - Condition condition = Conditions.isEqual(table.column(path.getColumnInfo().name()), Expressions.just(":id")); - - return createSelect(condition); } @Override - public String findAllById() { - - AggregatePath path = getRootIdPath(); - Condition condition = Conditions.in(table.column(path.getColumnInfo().name()), Expressions.just(":ids")); - - return createSelect(condition); - } - - @Override - public String findAllByCondition(BiFunction conditionSource) { - Condition condition = conditionSource.apply(table, aggregate); + public String findAll(@Nullable Condition condition) { return createSelect(condition); } - /** - * @return The {@link AggregatePath} to the id property of the aggregate root. - */ - private AggregatePath getRootIdPath() { - return context.getAggregatePath(aggregate).append(aggregate.getRequiredIdProperty()); - } - - String createSelect(Condition condition) { + String createSelect(@Nullable Condition condition) { AggregatePath rootPath = context.getAggregatePath(aggregate); QueryMeta queryMeta = createInlineQuery(rootPath, condition); @@ -168,7 +134,7 @@ private List createInlineQueries(PersistentPropertyPaths inlineQueries = new ArrayList<>(); - for (PersistentPropertyPath ppp : paths) { + for (PersistentPropertyPath ppp : paths) { QueryMeta queryMeta = createInlineQuery(context.getAggregatePath(ppp), null); inlineQueries.add(queryMeta); @@ -188,7 +154,7 @@ private List createInlineQueries(PersistentPropertyPaths entity = basePath.getRequiredLeafEntity(); Table table = Table.create(entity.getQualifiedTableName()); diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java index 80eb9a1a87..fe783882a5 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SqlGenerator.java @@ -15,11 +15,8 @@ */ package org.springframework.data.relational.core.sqlgeneration; -import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.sql.Condition; -import org.springframework.data.relational.core.sql.Table; - -import java.util.function.BiFunction; +import org.springframework.lang.Nullable; /** * Generates SQL statements for loading aggregates. @@ -28,13 +25,12 @@ * @since 3.2 */ public interface SqlGenerator { - String findAll(); - - String findById(); - String findAllById(); + default String findAll() { + return findAll(null); + } - String findAllByCondition(BiFunction conditionSource); + String findAll(@Nullable Condition condition); AliasFactory getAliasFactory(); } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ModifyingValueFunction.java b/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ModifyingValueFunction.java deleted file mode 100644 index 2be3e8c371..0000000000 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ModifyingValueFunction.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright 2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.data.relational.repository.query; - -import org.springframework.data.relational.core.dialect.Escaper; -import org.springframework.data.relational.core.query.ValueFunction; - -import java.util.function.Function; - -/** - * Value function that has an underlying value and a modifier that gets applied after the escaper. - * - * @author Jens Schauder - * @since 3.2 - */ -record ModifyingValueFunction(Object value, Function modifier) implements ValueFunction { - - static ModifyingValueFunction of(Object value, Function modifier) { - return new ModifyingValueFunction(value, modifier); - } - - @Override - public String apply(Escaper escaper) { - return modifier.apply(escaper.escape(value.toString())); - } - - @Override - public ValueFunction transform(Function transformation) { - return new ModifyingValueFunction(transformation.apply(value), modifier); - } -} diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java b/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java index 1261899b49..2f781e89c3 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/repository/query/ParameterMetadataProvider.java @@ -20,6 +20,7 @@ import java.util.List; import org.springframework.data.relational.core.dialect.Escaper; +import org.springframework.data.relational.core.query.ValueFunction; import org.springframework.data.repository.query.Parameter; import org.springframework.data.repository.query.Parameters; import org.springframework.data.repository.query.parser.Part; @@ -136,16 +137,12 @@ protected Object prepareParameterValue(@Nullable Object value, Class valueTyp return value; } - switch (partType) { - case STARTING_WITH: - return ModifyingValueFunction.of(value, s -> s + "%"); - case ENDING_WITH: - return ModifyingValueFunction.of(value, s -> "%" + s); - case CONTAINING: - case NOT_CONTAINING: - return ModifyingValueFunction.of(value, s -> "%" + s + "%"); - default: - return value; - } + return switch (partType) { + case STARTING_WITH -> (ValueFunction) escaper -> escaper.escape(value.toString()) + "%"; + case ENDING_WITH -> (ValueFunction) escaper -> "%" + escaper.escape(value.toString()); + case CONTAINING, NOT_CONTAINING -> (ValueFunction) escaper -> "%" + escaper.escape(value.toString()) + + "%"; + default -> value; + }; } } diff --git a/spring-data-relational/src/test/java/org/springframework/data/relational/core/mapping/DerivedSqlIdentifierUnitTests.java b/spring-data-relational/src/test/java/org/springframework/data/relational/core/mapping/DerivedSqlIdentifierUnitTests.java index 5742a2c4ed..bb62ab7b91 100644 --- a/spring-data-relational/src/test/java/org/springframework/data/relational/core/mapping/DerivedSqlIdentifierUnitTests.java +++ b/spring-data-relational/src/test/java/org/springframework/data/relational/core/mapping/DerivedSqlIdentifierUnitTests.java @@ -16,7 +16,6 @@ package org.springframework.data.relational.core.mapping; import static org.assertj.core.api.Assertions.*; -import static org.assertj.core.api.SoftAssertions.*; import org.junit.jupiter.api.Test; import org.springframework.data.relational.core.sql.IdentifierProcessing; @@ -44,7 +43,6 @@ public void quotedSimpleObjectIdentifierWithAdjustableLetterCasing() { assertThat(identifier.toSql(BRACKETS_LOWER_CASE)).isEqualTo("[somename]"); assertThat(identifier.getReference(BRACKETS_LOWER_CASE)).isEqualTo("someName"); assertThat(identifier.getReference()).isEqualTo("someName"); - } @Test // DATAJDBC-386 @@ -77,12 +75,12 @@ public void equality() { SqlIdentifier notSimple = SqlIdentifier.from(new DerivedSqlIdentifier("simple", false), new DerivedSqlIdentifier("not", false)); - assertSoftly(softly -> { + assertThat(basis).isEqualTo(equal).isEqualTo(SqlIdentifier.unquoted("simple")) + .hasSameHashCodeAs(SqlIdentifier.unquoted("simple")); + assertThat(equal).isEqualTo(basis); + assertThat(basis).isNotEqualTo(quoted); + assertThat(basis).isNotEqualTo(notSimple); - softly.assertThat(basis).isEqualTo(equal); - softly.assertThat(equal).isEqualTo(basis); - softly.assertThat(basis).isNotEqualTo(quoted); - softly.assertThat(basis).isNotEqualTo(notSimple); - }); + assertThat(quoted).isEqualTo(SqlIdentifier.quoted("SIMPLE")).hasSameHashCodeAs(SqlIdentifier.quoted("SIMPLE")); } } diff --git a/spring-data-relational/src/test/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorUnitTests.java b/spring-data-relational/src/test/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorUnitTests.java index 5721ce2b42..ade6e0dad1 100644 --- a/spring-data-relational/src/test/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorUnitTests.java +++ b/spring-data-relational/src/test/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorUnitTests.java @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.data.relational.core.sqlgeneration; import static org.springframework.data.relational.core.sqlgeneration.SqlAssert.*; @@ -28,7 +27,10 @@ import org.springframework.data.relational.core.dialect.PostgresDialect; import org.springframework.data.relational.core.mapping.AggregatePath; import org.springframework.data.relational.core.mapping.RelationalMappingContext; +import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; +import org.springframework.data.relational.core.sql.Conditions; +import org.springframework.data.relational.core.sql.Table; /** * Tests for {@link SingleQuerySqlGenerator}. @@ -76,7 +78,8 @@ void createSelectForFindAll() { @Test // GH-1446 void createSelectForFindById() { - String sql = sqlGenerator.findById(); + Table table = Table.create(persistentEntity.getQualifiedTableName()); + String sql = sqlGenerator.findAll(table.column("id").isEqualTo(Conditions.just(":id"))); SqlAssert baseSelect = assertThatParsed(sql).hasInlineView(); @@ -94,13 +97,14 @@ void createSelectForFindById() { col("\"id\"").as(alias("id")), // col("\"name\"").as(alias("name")) // ) // - .extractWhereClause().isEqualTo("\"trivial_aggregate\".\"id\" = :id"); + .extractWhereClause().isEqualTo("\"trivial_aggregate\".id = :id"); } @Test // GH-1446 void createSelectForFindAllById() { - String sql = sqlGenerator.findAllById(); + Table table = Table.create(persistentEntity.getQualifiedTableName()); + String sql = sqlGenerator.findAll(table.column("id").in(Conditions.just(":ids"))); SqlAssert baseSelect = assertThatParsed(sql).hasInlineView(); @@ -118,7 +122,7 @@ void createSelectForFindAllById() { col("\"id\"").as(alias("id")), // col("\"name\"").as(alias("name")) // ) // - .extractWhereClause().isEqualTo("\"trivial_aggregate\".\"id\" IN (:ids)"); + .extractWhereClause().isEqualTo("\"trivial_aggregate\".id IN (:ids)"); } } @@ -133,7 +137,8 @@ private AggregateWithSingleReference() { @Test // GH-1446 void createSelectForFindById() { - String sql = sqlGenerator.findById(); + Table table = Table.create(persistentEntity.getQualifiedTableName()); + String sql = sqlGenerator.findAll(table.column("id").isEqualTo(Conditions.just(":id"))); String rootRowNumber = rnAlias(); String rootCount = rcAlias(); @@ -167,7 +172,7 @@ void createSelectForFindById() { col("\"id\"").as(alias("id")), // col("\"name\"").as(alias("name")) // ) // - .extractWhereClause().isEqualTo("\"single_reference_aggregate\".\"id\" = :id"); + .extractWhereClause().isEqualTo("\"single_reference_aggregate\".id = :id"); baseSelect.hasInlineViewSelectingFrom("\"trivial_aggregate\"") // .hasExactlyColumns( // rn(col("\"single_reference_aggregate\"")).as(trivialsRowNumber), // @@ -206,13 +211,14 @@ record SingleReferenceAggregate(@Id Long id, String name, List private class AbstractTestFixture { final Class aggregateRootType; final SingleQuerySqlGenerator sqlGenerator; + final RelationalPersistentEntity persistentEntity; final AliasFactory aliases; private AbstractTestFixture(Class aggregateRootType) { this.aggregateRootType = aggregateRootType; - this.sqlGenerator = new SingleQuerySqlGenerator(context, new AliasFactory(), dialect, - context.getRequiredPersistentEntity(aggregateRootType)); + this.persistentEntity = context.getRequiredPersistentEntity(aggregateRootType); + this.sqlGenerator = new SingleQuerySqlGenerator(context, new AliasFactory(), dialect, persistentEntity); this.aliases = sqlGenerator.getAliasFactory(); } From 83530f4f3b8c98e389bc1f58e96935165c406d66 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Tue, 26 Sep 2023 11:10:28 +0200 Subject: [PATCH 8/9] Add JMH benchmarks and optimizations to affected components. --- pom.xml | 95 +++++++++++++++++++ .../jdbc/core/convert/AggregateReader.java | 39 +------- .../core/sqlgeneration/BenchmarkSettings.java | 40 ++++++++ .../SingleQuerySqlGeneratorBenchmark.java | 66 +++++++++++++ .../core/mapping/DefaultAggregatePath.java | 34 +++++-- .../relational/core/sql/DefaultSelect.java | 7 +- .../core/sql/DefaultSelectBuilder.java | 11 ++- .../relational/core/sql/SelectBuilder.java | 28 +++++- .../core/sql/render/TypedSubtreeVisitor.java | 31 +++++- .../SingleQuerySqlGenerator.java | 6 +- 10 files changed, 290 insertions(+), 67 deletions(-) create mode 100644 spring-data-relational/src/jmh/java/org/springframework/data/relational/core/sqlgeneration/BenchmarkSettings.java create mode 100644 spring-data-relational/src/jmh/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorBenchmark.java diff --git a/pom.xml b/pom.xml index 400d09babe..c57d0287a3 100644 --- a/pom.xml +++ b/pom.xml @@ -47,6 +47,9 @@ 4.2.0 1.0.1 + + 1.37 + 0.4.0.BUILD-SNAPSHOT 2017 @@ -154,6 +157,98 @@ + + + jmh + + + com.github.mp911de.microbenchmark-runner + microbenchmark-runner-junit5 + ${mbr.version} + test + + + org.openjdk.jmh + jmh-core + ${jmh.version} + test + + + org.openjdk.jmh + jmh-generator-annprocess + ${jmh.version} + test + + + + + + org.codehaus.mojo + build-helper-maven-plugin + 3.3.0 + + + add-source + generate-sources + + add-test-source + + + + src/jmh/java + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + true + + + + + org.apache.maven.plugins + maven-failsafe-plugin + + true + + + + org.codehaus.mojo + exec-maven-plugin + 3.1.0 + + + run-benchmarks + pre-integration-test + + exec + + + test + java + + -classpath + + org.openjdk.jmh.Main + .* + + + + + + + + + + jitpack.io + https://jitpack.io + + + diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java index b51b457359..c765c2eb6b 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java @@ -69,10 +69,7 @@ class AggregateReader { this.aggregate = aggregate; this.jdbcTemplate = jdbcTemplate; this.table = Table.create(aggregate.getQualifiedTableName()); - - this.sqlGenerator = new CachingSqlGenerator( - new SingleQuerySqlGenerator(converter.getMappingContext(), aliasFactory, dialect, aggregate)); - + this.sqlGenerator = new SingleQuerySqlGenerator(converter.getMappingContext(), aliasFactory, dialect, aggregate); this.extractor = new RowDocumentResultSetExtractor(converter.getMappingContext(), createPathToColumnMapping(aliasFactory)); } @@ -187,38 +184,4 @@ public String keyColumn(AggregatePath path) { }; } - /** - * A wrapper for the {@link org.springframework.data.relational.core.sqlgeneration.SqlGenerator} that caches the - * generated statements. - * - * @author Jens Schauder - * @since 3.2 - */ - static class CachingSqlGenerator implements SqlGenerator { - - private final SqlGenerator delegate; - private final String findAll; - - public CachingSqlGenerator(SqlGenerator delegate) { - - this.delegate = delegate; - this.findAll = delegate.findAll(); - } - - @Override - public String findAll() { - return findAll; - } - - @Override - public String findAll(@Nullable Condition condition) { - return delegate.findAll(condition); - } - - @Override - public AliasFactory getAliasFactory() { - return delegate.getAliasFactory(); - } - - } } diff --git a/spring-data-relational/src/jmh/java/org/springframework/data/relational/core/sqlgeneration/BenchmarkSettings.java b/spring-data-relational/src/jmh/java/org/springframework/data/relational/core/sqlgeneration/BenchmarkSettings.java new file mode 100644 index 0000000000..439824bf3c --- /dev/null +++ b/spring-data-relational/src/jmh/java/org/springframework/data/relational/core/sqlgeneration/BenchmarkSettings.java @@ -0,0 +1,40 @@ +/* + * Copyright 2019-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.data.relational.core.sqlgeneration; + +import java.util.concurrent.TimeUnit; + +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Warmup; + +/** + * Global benchmark settings. + * + * @author Mark Paluch + */ +@Warmup(iterations = 5, time = 2000, timeUnit = TimeUnit.MILLISECONDS) +@Measurement(iterations = 5, time = 1000, timeUnit = TimeUnit.MILLISECONDS) +@Fork(value = 1, warmups = 0) +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.SECONDS) +public abstract class BenchmarkSettings { + +} diff --git a/spring-data-relational/src/jmh/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorBenchmark.java b/spring-data-relational/src/jmh/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorBenchmark.java new file mode 100644 index 0000000000..f8a1399b5a --- /dev/null +++ b/spring-data-relational/src/jmh/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGeneratorBenchmark.java @@ -0,0 +1,66 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.relational.core.sqlgeneration; + +import jmh.mbr.junit5.Microbenchmark; + +import java.util.List; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.springframework.data.annotation.Id; +import org.springframework.data.relational.core.dialect.PostgresDialect; +import org.springframework.data.relational.core.mapping.RelationalMappingContext; +import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; + +/** + * Benchmark for {@link SingleQuerySqlGenerator}. + * + * @author Mark Paluch + */ +@Microbenchmark +public class SingleQuerySqlGeneratorBenchmark extends BenchmarkSettings { + + @Benchmark + public String findAll(StateHolder state) { + return new SingleQuerySqlGenerator(state.context, state.aliasFactory, PostgresDialect.INSTANCE, + state.persistentEntity).findAll(null); + } + + @State(Scope.Benchmark) + public static class StateHolder { + + RelationalMappingContext context = new RelationalMappingContext(); + + RelationalPersistentEntity persistentEntity; + + AliasFactory aliasFactory = new AliasFactory(); + + @Setup + public void setup() { + persistentEntity = context.getRequiredPersistentEntity(SingleReferenceAggregate.class); + } + } + + record TrivialAggregate(@Id Long id, String name) { + } + + record SingleReferenceAggregate(@Id Long id, String name, List trivials) { + } + +} diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/DefaultAggregatePath.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/DefaultAggregatePath.java index f924ba9d39..015857e985 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/DefaultAggregatePath.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/mapping/DefaultAggregatePath.java @@ -20,6 +20,7 @@ import java.util.Objects; import org.springframework.data.mapping.PersistentPropertyPath; +import org.springframework.data.util.Lazy; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -38,6 +39,10 @@ class DefaultAggregatePath implements AggregatePath { private final @Nullable PersistentPropertyPath path; + private final Lazy tableInfo = Lazy.of(() -> TableInfo.of(this)); + + private final Lazy columnInfo = Lazy.of(() -> ColumnInfo.of(this)); + @SuppressWarnings("unchecked") DefaultAggregatePath(RelationalMappingContext context, PersistentPropertyPath path) { @@ -189,14 +194,24 @@ private AggregatePath getTableOwningAncestor() { return AggregatePathTraversal.getTableOwningPath(this); } + /** + * Creates an {@link Iterator} that iterates over the current path and all ancestors. It will start with the current + * path, followed by its parent until ending with the root. + */ @Override - public String toString() { - return "AggregatePath[" - + (rootType == null ? path.getBaseProperty().getOwner().getType().getName() : rootType.getName()) + "]" - + ((isRoot()) ? "/" : path.toDotPath()); + public Iterator iterator() { + return new AggregatePathIterator(this); } + @Override + public TableInfo getTableInfo() { + return this.tableInfo.get(); + } + @Override + public ColumnInfo getColumnInfo() { + return this.columnInfo.get(); + } @Override public boolean equals(Object o) { @@ -215,13 +230,12 @@ public int hashCode() { return Objects.hash(context, rootType, path); } - /** - * Creates an {@link Iterator} that iterates over the current path and all ancestors. It will start with the current - * path, followed by its parent until ending with the root. - */ + @Override - public Iterator iterator() { - return new AggregatePathIterator(this); + public String toString() { + return "AggregatePath[" + + (rootType == null ? path.getBaseProperty().getOwner().getType().getName() : rootType.getName()) + "]" + + ((isRoot()) ? "/" : path.toDotPath()); } private static class AggregatePathIterator implements Iterator { diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/DefaultSelect.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/DefaultSelect.java index 204c741187..e239521488 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/DefaultSelect.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/DefaultSelect.java @@ -19,6 +19,7 @@ import java.util.Collections; import java.util.List; import java.util.OptionalLong; +import java.util.function.Consumer; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -92,15 +93,17 @@ public void visit(Visitor visitor) { Assert.notNull(visitor, "Visitor must not be null"); + Consumer action = it -> it.visit(visitor); + visitor.enter(this); selectList.visit(visitor); from.visit(visitor); - joins.forEach(it -> it.visit(visitor)); + joins.forEach(action); visitIfNotNull(where, visitor); - orderBy.forEach(it -> it.visit(visitor)); + orderBy.forEach(action); visitor.leave(this); } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/DefaultSelectBuilder.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/DefaultSelectBuilder.java index e06da61327..08ca542af2 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/DefaultSelectBuilder.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/DefaultSelectBuilder.java @@ -200,11 +200,14 @@ public SelectLock lock(LockMode lockMode) { } @Override - public Select build() { + public Select build(boolean validate) { DefaultSelect select = new DefaultSelect(distinct, selectList, from, limit, offset, joins, where, orderBy, lockMode); - SelectValidator.validate(select); + + if (validate) { + SelectValidator.validate(select); + } return select; } @@ -359,9 +362,9 @@ public SelectLock lock(LockMode lockMode) { } @Override - public Select build() { + public Select build(boolean validate) { selectBuilder.join(finishJoin()); - return selectBuilder.build(); + return selectBuilder.build(validate); } } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/SelectBuilder.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/SelectBuilder.java index 140eb7ad14..8f6c8c2075 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/SelectBuilder.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/SelectBuilder.java @@ -297,6 +297,7 @@ interface SelectFromAndJoin * @param offset row offset, zero-based. * @return {@code this} builder. */ + @Override SelectFromAndJoin limitOffset(long limit, long offset); /** @@ -305,6 +306,7 @@ interface SelectFromAndJoin * @param limit rows to read. * @return {@code this} builder. */ + @Override SelectFromAndJoin limit(long limit); /** @@ -313,6 +315,7 @@ interface SelectFromAndJoin * @param offset start offset. * @return {@code this} builder. */ + @Override SelectFromAndJoin offset(long offset); } @@ -331,6 +334,7 @@ interface SelectFromAndJoinCondition * @param offset row offset, zero-based. * @return {@code this} builder. */ + @Override SelectFromAndJoin limitOffset(long limit, long offset); /** @@ -339,6 +343,7 @@ interface SelectFromAndJoinCondition * @param limit rows to read. * @return {@code this} builder. */ + @Override SelectFromAndJoin limit(long limit); /** @@ -347,6 +352,7 @@ interface SelectFromAndJoinCondition * @param offset start offset. * @return {@code this} builder. */ + @Override SelectFromAndJoin offset(long offset); } @@ -488,11 +494,11 @@ interface SelectJoin extends SelectLock, BuildSelect { SelectOn leftOuterJoin(TableLike table); /** - * Declar a join, where the join type ({@code INNER}, {@code LEFT OUTER}, {@code RIGHT OUTER}, {@code FULL OUTER}) + * Declare a join, where the join type ({@code INNER}, {@code LEFT OUTER}, {@code RIGHT OUTER}, {@code FULL OUTER}) * is specified by an extra argument. - * + * * @param table the table to join. Must not be {@literal null}. - * @param joinType the type of joi. Must not be {@literal null}. + * @param joinType the type of join. Must not be {@literal null}. * @return {@code this} builder. */ SelectOn join(TableLike table, Join.JoinType joinType); @@ -577,8 +583,20 @@ interface BuildSelect { * Build the {@link Select} statement and verify basic relationship constraints such as all referenced columns have * a {@code FROM} or {@code JOIN} table import. * - * @return the build and immutable {@link Select} statement. + * @return the built and immutable {@link Select} statement. + */ + default Select build() { + return build(true); + } + + /** + * Build the {@link Select} statement. + * + * @param validate whether to validate the generated select by checking basic relationship constraints such as all + * referenced columns have a {@code FROM} or {@code JOIN} table import. + * @return the built and immutable {@link Select} statement. + * @since 3.2 */ - Select build(); + Select build(boolean validate); } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/TypedSubtreeVisitor.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/TypedSubtreeVisitor.java index 26d1e9d178..23d7d0e47e 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/TypedSubtreeVisitor.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sql/render/TypedSubtreeVisitor.java @@ -21,6 +21,7 @@ import org.springframework.data.relational.core.sql.Visitable; import org.springframework.data.relational.core.sql.Visitor; import org.springframework.lang.Nullable; +import org.springframework.util.ConcurrentReferenceHashMap; /** * Type-filtering {@link DelegatingVisitor visitor} applying a {@link Class type filter} derived from the generic type @@ -37,28 +38,36 @@ * {@link Visitable}. * *

- * + * * @author Mark Paluch * @since 1.1 * @see FilteredSubtreeVisitor */ abstract class TypedSubtreeVisitor extends DelegatingVisitor { + private static final ConcurrentReferenceHashMap, ResolvableType> refCache = new ConcurrentReferenceHashMap<>(); + private static final ConcurrentReferenceHashMap, Assignable> assignable = new ConcurrentReferenceHashMap<>(); + private final ResolvableType type; private @Nullable Visitable currentSegment; + enum Assignable { + YES, NO, + } + /** * Creates a new {@link TypedSubtreeVisitor}. */ TypedSubtreeVisitor() { - this.type = ResolvableType.forClass(getClass()).as(TypedSubtreeVisitor.class).getGeneric(0); + this.type = refCache.computeIfAbsent(this.getClass(), + key -> ResolvableType.forClass(key).as(TypedSubtreeVisitor.class).getGeneric(0)); } /** * Creates a new {@link TypedSubtreeVisitor} with an explicitly provided type. */ - TypedSubtreeVisitor(Class type) { - this.type = ResolvableType.forType(type); + TypedSubtreeVisitor(Class type) { + this.type = refCache.computeIfAbsent(type, key -> ResolvableType.forClass(type)); } /** @@ -117,7 +126,7 @@ public final Delegation doEnter(Visitable segment) { if (currentSegment == null) { - if (this.type.isInstance(segment)) { + if (isAssignable(this.type, segment)) { currentSegment = segment; return enterMatched((T) segment); @@ -142,4 +151,16 @@ public final Delegation doLeave(Visitable segment) { return leaveNested(segment); } } + + private static boolean isAssignable(ResolvableType type, Visitable segment) { + + Assignable assignable = TypedSubtreeVisitor.assignable.get(segment.getClass()); + + if (assignable == null) { + assignable = type.isInstance(segment) ? Assignable.YES : Assignable.NO; + TypedSubtreeVisitor.assignable.put(segment.getClass(), assignable); + } + + return assignable == Assignable.YES; + } } diff --git a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java index ff0a61f771..9326a55f1f 100644 --- a/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java +++ b/spring-data-relational/src/main/java/org/springframework/data/relational/core/sqlgeneration/SingleQuerySqlGenerator.java @@ -104,7 +104,7 @@ String createSelect(@Nullable Condition condition) { finalColumns.add(rootIdExpression); Select fullQuery = StatementBuilder.select(finalColumns).from(inlineQuery).orderBy(rootIdExpression, just("rn")) - .build(); + .build(false); return SqlRenderer.create(new RenderContextFactory(dialect).createRenderContext()).render(fullQuery); } @@ -118,7 +118,7 @@ private InlineQuery createMainSelect(List columns, AggregatePath roo select = applyJoins(rootPath, inlineQueries, select); SelectBuilder.BuildSelect buildSelect = applyWhereCondition(rootPath, inlineQueries, select); - Select mainSelect = buildSelect.build(); + Select mainSelect = buildSelect.build(false); return InlineQuery.create(mainSelect, "main"); } @@ -215,7 +215,7 @@ private QueryMeta createInlineQuery(AggregatePath basePath, @Nullable Condition SelectBuilder.BuildSelect buildSelect = condition != null ? select.where(condition) : select; - InlineQuery inlineQuery = InlineQuery.create(buildSelect.build(), + InlineQuery inlineQuery = InlineQuery.create(buildSelect.build(false), aliases.getTableAlias(context.getAggregatePath(entity))); return QueryMeta.of(basePath, inlineQuery, columnAliases, just(id), just(backReferenceAlias), just(keyAlias), just(rowNumberAlias), just(rowCountAlias)); From efd8052aa86d9491a7db83457a846c1036d2b1fd Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Tue, 26 Sep 2023 11:34:26 +0200 Subject: [PATCH 9/9] Polishing. --- .../jdbc/core/convert/AggregateReader.java | 75 ++++++++++--------- .../RowDocumentResultSetExtractor.java | 2 +- .../SingleQueryDataAccessStrategy.java | 13 ++-- ...JdbcAggregateTemplateIntegrationTests.java | 18 ++++- 4 files changed, 63 insertions(+), 45 deletions(-) diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java index c765c2eb6b..4070894735 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/AggregateReader.java @@ -37,10 +37,10 @@ import org.springframework.data.relational.core.sqlgeneration.SqlGenerator; import org.springframework.data.relational.domain.RowDocument; import org.springframework.data.util.Streamable; +import org.springframework.jdbc.core.ResultSetExtractor; import org.springframework.jdbc.core.namedparam.MapSqlParameterSource; import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; import org.springframework.lang.Nullable; -import org.springframework.util.Assert; /** * Reads complete Aggregates from the database, by generating appropriate SQL using a {@link SingleQuerySqlGenerator} @@ -53,13 +53,14 @@ * @author Mark Paluch * @since 3.2 */ -class AggregateReader { +class AggregateReader implements PathToColumnMapping { private final RelationalPersistentEntity aggregate; private final Table table; private final SqlGenerator sqlGenerator; private final JdbcConverter converter; private final NamedParameterJdbcOperations jdbcTemplate; + private final AliasFactory aliasFactory; private final RowDocumentResultSetExtractor extractor; AggregateReader(Dialect dialect, JdbcConverter converter, AliasFactory aliasFactory, @@ -70,8 +71,25 @@ class AggregateReader { this.jdbcTemplate = jdbcTemplate; this.table = Table.create(aggregate.getQualifiedTableName()); this.sqlGenerator = new SingleQuerySqlGenerator(converter.getMappingContext(), aliasFactory, dialect, aggregate); - this.extractor = new RowDocumentResultSetExtractor(converter.getMappingContext(), - createPathToColumnMapping(aliasFactory)); + this.aliasFactory = aliasFactory; + this.extractor = new RowDocumentResultSetExtractor(converter.getMappingContext(), this); + } + + @Override + public String column(AggregatePath path) { + + String alias = aliasFactory.getColumnAlias(path); + + if (alias == null) { + throw new IllegalStateException(String.format("Alias for '%s' must not be null", path)); + } + + return alias; + } + + @Override + public String keyColumn(AggregatePath path) { + return aliasFactory.getKeyAlias(path); } @Nullable @@ -84,30 +102,34 @@ public T findById(Object id) { @Nullable public T findOne(Query query) { - - MapSqlParameterSource parameterSource = new MapSqlParameterSource(); - Condition condition = createCondition(query, parameterSource); - - return jdbcTemplate.query(sqlGenerator.findAll(condition), parameterSource, this::extractZeroOrOne); - } - - public List findAll() { - return jdbcTemplate.query(sqlGenerator.findAll(), this::extractAll); + return doFind(query, this::extractZeroOrOne); } public List findAllById(Iterable ids) { Collection identifiers = ids instanceof Collection idl ? idl : Streamable.of(ids).toList(); - Query query = Query.query(Criteria.where(aggregate.getRequiredIdProperty().getName()).in(identifiers)).limit(1); + Query query = Query.query(Criteria.where(aggregate.getRequiredIdProperty().getName()).in(identifiers)); return findAll(query); } + @SuppressWarnings("ConstantConditions") + public List findAll() { + return jdbcTemplate.query(sqlGenerator.findAll(), this::extractAll); + } + public List findAll(Query query) { + return doFind(query, this::extractAll); + } + + @SuppressWarnings("ConstantConditions") + private R doFind(Query query, ResultSetExtractor extractor) { MapSqlParameterSource parameterSource = new MapSqlParameterSource(); Condition condition = createCondition(query, parameterSource); - return jdbcTemplate.query(sqlGenerator.findAll(condition), parameterSource, this::extractAll); + String sql = sqlGenerator.findAll(condition); + + return jdbcTemplate.query(sql, parameterSource, extractor); } @Nullable @@ -128,7 +150,7 @@ private Condition createCondition(Query query, MapSqlParameterSource parameterSo * * @param rs the {@link ResultSet} from which to extract the data. Must not be {(}@literal null}. * @return a {@code List} of aggregates, fully converted. - * @throws SQLException + * @throws SQLException on underlying JDBC errors. */ private List extractAll(ResultSet rs) throws SQLException { @@ -146,10 +168,10 @@ private List extractAll(ResultSet rs) throws SQLException { * {@link RowDocumentResultSetExtractor} and the {@link JdbcConverter}. When used as a method reference this conforms * to the {@link org.springframework.jdbc.core.ResultSetExtractor} contract. * - * @param @param rs the {@link ResultSet} from which to extract the data. Must not be {(}@literal null}. + * @param rs the {@link ResultSet} from which to extract the data. Must not be {(}@literal null}. * @return The single instance when the conversion results in exactly one instance. If the {@literal ResultSet} is * empty, null is returned. - * @throws SQLException + * @throws SQLException on underlying JDBC errors. * @throws IncorrectResultSizeDataAccessException when the conversion yields more than one instance. */ @Nullable @@ -167,21 +189,4 @@ private T extractZeroOrOne(ResultSet rs) throws SQLException { return null; } - private PathToColumnMapping createPathToColumnMapping(AliasFactory aliasFactory) { - return new PathToColumnMapping() { - @Override - public String column(AggregatePath path) { - - String alias = aliasFactory.getColumnAlias(path); - Assert.notNull(alias, () -> "alias for >" + path + "< must not be null"); - return alias; - } - - @Override - public String keyColumn(AggregatePath path) { - return aliasFactory.getKeyAlias(path); - } - }; - } - } diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/RowDocumentResultSetExtractor.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/RowDocumentResultSetExtractor.java index cb3df05bc9..45b264050d 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/RowDocumentResultSetExtractor.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/RowDocumentResultSetExtractor.java @@ -159,7 +159,7 @@ private class RowDocumentIterator implements Iterator { */ private boolean hasNext; - RowDocumentIterator(RelationalPersistentEntity entity, ResultSet resultSet) throws SQLException { + RowDocumentIterator(RelationalPersistentEntity entity, ResultSet resultSet) { ResultSetAdapter adapter = ResultSetAdapter.INSTANCE; diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java index a609619c2d..d5fc206e0c 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SingleQueryDataAccessStrategy.java @@ -16,6 +16,7 @@ package org.springframework.data.jdbc.core.convert; +import java.util.List; import java.util.Optional; import org.springframework.data.domain.Pageable; @@ -56,22 +57,22 @@ public T findById(Object id, Class domainType) { } @Override - public Iterable findAll(Class domainType) { + public List findAll(Class domainType) { return getReader(domainType).findAll(); } @Override - public Iterable findAllById(Iterable ids, Class domainType) { + public List findAllById(Iterable ids, Class domainType) { return getReader(domainType).findAllById(ids); } @Override - public Iterable findAll(Class domainType, Sort sort) { + public List findAll(Class domainType, Sort sort) { throw new UnsupportedOperationException(); } @Override - public Iterable findAll(Class domainType, Pageable pageable) { + public List findAll(Class domainType, Pageable pageable) { throw new UnsupportedOperationException(); } @@ -81,12 +82,12 @@ public Optional findOne(Query query, Class domainType) { } @Override - public Iterable findAll(Query query, Class domainType) { + public List findAll(Query query, Class domainType) { return getReader(domainType).findAll(query); } @Override - public Iterable findAll(Query query, Class domainType, Pageable pageable) { + public List findAll(Query query, Class domainType, Pageable pageable) { throw new UnsupportedOperationException(); } diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java index e76825b29e..858a69a791 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/AbstractJdbcAggregateTemplateIntegrationTests.java @@ -23,8 +23,16 @@ import static org.springframework.data.jdbc.testing.TestDatabaseFeatures.Feature.*; import java.time.LocalDateTime; -import java.util.*; import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; import java.util.function.Function; import java.util.stream.IntStream; @@ -49,6 +57,7 @@ import org.springframework.data.jdbc.core.convert.JdbcConverter; import org.springframework.data.jdbc.testing.EnabledOnFeature; import org.springframework.data.jdbc.testing.IntegrationTest; +import org.springframework.data.jdbc.testing.TestClass; import org.springframework.data.jdbc.testing.TestConfiguration; import org.springframework.data.jdbc.testing.TestDatabaseFeatures; import org.springframework.data.mapping.context.InvalidPersistentPropertyPath; @@ -63,6 +72,7 @@ import org.springframework.data.relational.core.query.Query; import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.ContextConfiguration; /** * Integration tests for {@link JdbcAggregateTemplate}. @@ -1927,8 +1937,8 @@ static class WithInsertOnly { static class Config { @Bean - Class testClass() { - return JdbcAggregateTemplateIntegrationTests.class; + TestClass testClass() { + return TestClass.of(JdbcAggregateTemplateIntegrationTests.class); } @Bean @@ -1938,9 +1948,11 @@ JdbcAggregateOperations operations(ApplicationEventPublisher publisher, Relation } } + @ContextConfiguration(classes = Config.class) static class JdbcAggregateTemplateIntegrationTests extends AbstractJdbcAggregateTemplateIntegrationTests {} @ActiveProfiles(value = PROFILE_SINGLE_QUERY_LOADING) + @ContextConfiguration(classes = Config.class) static class JdbcAggregateTemplateSingleQueryLoadingIntegrationTests extends AbstractJdbcAggregateTemplateIntegrationTests {