Skip to content

Commit

Permalink
Add ANN options to RowFilter
Browse files Browse the repository at this point in the history
  • Loading branch information
adelapena committed Jan 22, 2025
1 parent b6bedb9 commit 42a25dc
Show file tree
Hide file tree
Showing 8 changed files with 452 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
import org.apache.cassandra.cql3.*;
import org.apache.cassandra.cql3.functions.Function;
import org.apache.cassandra.cql3.statements.Bound;
import org.apache.cassandra.cql3.statements.SelectOptions;
import org.apache.cassandra.cql3.statements.StatementType;
import org.apache.cassandra.db.*;
import org.apache.cassandra.index.ANNOptions;
import org.apache.cassandra.db.filter.RowFilter;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.marshal.DecimalType;
Expand Down Expand Up @@ -88,6 +90,7 @@ public class StatementRestrictions
public static final String VECTOR_INDEX_PRESENT_NOT_SUPPORT_GEO_DISTANCE_MESSAGE =
"Vector index present, but configuration does not support GEO_DISTANCE queries. GEO_DISTANCE requires similarity_function 'euclidean'";
public static final String VECTOR_INDEXES_UNSUPPORTED_OP_MESSAGE = "Vector indexes only support ANN and GEO_DISTANCE queries";
public static final String ANN_OPTIONS_WITHOUT_ORDER_BY_ANN = "ANN options specified without ORDER BY ... ANN OF ...";

/**
* The Column Family meta data
Expand Down Expand Up @@ -985,12 +988,25 @@ private boolean hasUnrestrictedClusteringColumns()
return table.clusteringColumns().size() != clusteringColumnsRestrictions.size();
}

public RowFilter getRowFilter(IndexRegistry indexManager, QueryOptions options, QueryState queryState)
public RowFilter getRowFilter(IndexRegistry indexManager, QueryOptions options, QueryState queryState, SelectOptions selectOptions)
{
ANNOptions annOptions = selectOptions.getANNOptions();

if (filterRestrictions.isEmpty() && children.isEmpty())
{
if (annOptions != ANNOptions.NONE)
throw new InvalidRequestException(ANN_OPTIONS_WITHOUT_ORDER_BY_ANN);

return RowFilter.NONE;
}

RowFilter rowFilter = RowFilter.builder(indexManager, annOptions)
.buildFromRestrictions(this, table, options, queryState);

if (annOptions != ANNOptions.NONE && !rowFilter.hasANN())
throw new InvalidRequestException(ANN_OPTIONS_WITHOUT_ORDER_BY_ANN);

return RowFilter.builder(indexManager).buildFromRestrictions(this, table, options, queryState);
return rowFilter;
}

/**
Expand Down
31 changes: 9 additions & 22 deletions src/java/org/apache/cassandra/cql3/statements/SelectOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,11 @@
package org.apache.cassandra.cql3.statements;

import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.cassandra.exceptions.RequestValidationException;
import org.apache.cassandra.exceptions.SyntaxException;
import org.apache.cassandra.index.ANNOptions;

/**
* {@code WITH option1=... AND option2=...} options for SELECT statements.
Expand All @@ -33,28 +29,19 @@ public class SelectOptions extends PropertyDefinitions
{
public static final SelectOptions EMPTY = new SelectOptions();

private static final Logger logger = LoggerFactory.getLogger(SelectOptions.class);

private static final String INDEX_OPTIONS = "index_options";

private static final Set<String> keywords = new HashSet<>();
private static final Set<String> obsoleteKeywords = new HashSet<>();

static
{
keywords.add(INDEX_OPTIONS);
}
private static final Set<String> keywords = Collections.singleton(ANNOptions.SELECT_OPTIONS_NAME);

public void validate() throws RequestValidationException
{
validate(keywords, obsoleteKeywords);
Map<String, String> indexOptions = getIndexOptions();
logger.info("{}}: {}", INDEX_OPTIONS, indexOptions); // TODO: remove after development
validate(keywords, Collections.emptySet());
getANNOptions();
}

private Map<String, String> getIndexOptions() throws SyntaxException
/**
* @return the ANN options within these options, or {@link ANNOptions#NONE} if no options are present
*/
public ANNOptions getANNOptions() throws SyntaxException
{
Map<String, String> options = getMap(INDEX_OPTIONS);
return options == null ? Collections.emptyMap() : options;
return ANNOptions.fromSelectOptions(this);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -339,14 +339,14 @@ private boolean canSkipPaging(DataLimits userLimits, PageSize pageSize, boolean
topK;
}

@Override
public ResultMessage.Rows execute(QueryState queryState, QueryOptions options, long queryStartNanoTime)
{
ConsistencyLevel cl = options.getConsistency();
checkNotNull(cl, "Invalid empty consistency level");

cl.validateForRead();
validateQueryOptions(queryState, options);
selectOptions.validate();

int nowInSec = options.getNowInSeconds(queryState);
int userLimit = getLimit(options);
Expand Down Expand Up @@ -426,6 +426,8 @@ public ReadQuery getQuery(QueryState queryState,
checkFalse(userOffset != NO_OFFSET, String.format(TOPK_OFFSET_ERROR, userOffset));
}

selectOptions.validate();

return query;
}

Expand Down Expand Up @@ -607,6 +609,7 @@ private ResultMessage.Rows processResults(PartitionIterator partitions,
return new ResultMessage.Rows(rset);
}

@Override
public ResultMessage.Rows executeLocally(QueryState state, QueryOptions options) throws RequestExecutionException, RequestValidationException
{
return executeInternal(state, options, options.getNowInSeconds(state), System.nanoTime());
Expand Down Expand Up @@ -998,7 +1001,7 @@ private NavigableSet<Clustering<?>> getRequestedRows(QueryOptions options, Query
public RowFilter getRowFilter(QueryOptions options, QueryState state) throws InvalidRequestException
{
IndexRegistry indexRegistry = IndexRegistry.obtain(table);
return restrictions.getRowFilter(indexRegistry, options, state);
return restrictions.getRowFilter(indexRegistry, options, state, selectOptions);
}

private ResultSet process(PartitionIterator partitions,
Expand Down
6 changes: 6 additions & 0 deletions src/java/org/apache/cassandra/db/ReadCommand.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.cassandra.index.ANNOptions;
import org.apache.cassandra.db.filter.ClusteringIndexFilter;
import org.apache.cassandra.db.filter.ColumnFilter;
import org.apache.cassandra.db.filter.DataLimits;
Expand Down Expand Up @@ -779,6 +780,11 @@ public String toCQLString()

if (limits() != DataLimits.NONE)
sb.append(' ').append(limits());

ANNOptions annOptions = rowFilter().annOptions;
if (annOptions != ANNOptions.NONE)
sb.append(" WITH ").append(annOptions.toCQLString());

return sb.toString();
}

Expand Down
57 changes: 41 additions & 16 deletions src/java/org/apache/cassandra/db/filter/RowFilter.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.apache.cassandra.db.transform.Transformation;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.guardrails.Guardrails;
import org.apache.cassandra.index.ANNOptions;
import org.apache.cassandra.index.Index;
import org.apache.cassandra.index.IndexRegistry;
import org.apache.cassandra.index.sai.utils.GeoUtil;
Expand Down Expand Up @@ -79,13 +80,16 @@ public class RowFilter
private static final Logger logger = LoggerFactory.getLogger(RowFilter.class);

public static final Serializer serializer = new Serializer();
public static final RowFilter NONE = new RowFilter(FilterElement.NONE);
public static final RowFilter NONE = new RowFilter(FilterElement.NONE, ANNOptions.NONE);

protected final FilterElement root;
private final FilterElement root;

protected RowFilter(FilterElement root)
public final ANNOptions annOptions;

protected RowFilter(FilterElement root, ANNOptions annOptions)
{
this.root = root;
this.annOptions = annOptions;
}

public FilterElement root()
Expand All @@ -101,6 +105,19 @@ public List<Expression> expressions()
return root.traversedExpressions();
}

/**
* @return {@code true} if this filter contains any expression with an ANN operator, {@code false} otherwise.
*/
public boolean hasANN()
{
for (Expression expression : root.expressions()) // ANN expressions are always on the first tree level
{
if (expression.operator == Operator.ANN)
return true;
}
return false;
}

/**
* @return {@code true} if this filter contains any disjunction, {@code false} otherwise.
*/
Expand Down Expand Up @@ -267,7 +284,7 @@ public RowFilter without(Expression expression)
if (root.size() == 1)
return RowFilter.NONE;

return new RowFilter(root.filter(e -> !e.equals(expression)));
return new RowFilter(root.filter(e -> !e.equals(expression)), annOptions);
}

public RowFilter withoutExpressions()
Expand All @@ -280,12 +297,12 @@ public RowFilter withoutExpressions()
*/
public RowFilter withoutDisjunctions()
{
return new RowFilter(root.withoutDisjunctions());
return new RowFilter(root.withoutDisjunctions(), annOptions);
}

public RowFilter restrict(Predicate<Expression> filter)
{
return new RowFilter(root.filter(filter));
return new RowFilter(root.filter(filter), annOptions);
}

public boolean isEmpty()
Expand All @@ -301,38 +318,43 @@ public String toString()

public static Builder builder()
{
return new Builder(null);
return new Builder(null, ANNOptions.NONE);
}

public static Builder builder(IndexRegistry indexRegistry)
public static Builder builder(IndexRegistry indexRegistry, ANNOptions annOptions)
{
return new Builder(indexRegistry);
return new Builder(indexRegistry, annOptions);
}

public static class Builder
{
private FilterElement.Builder current = new FilterElement.Builder(false);

private final IndexRegistry indexRegistry;
private final ANNOptions annOptions;

public Builder(IndexRegistry indexRegistry)
public Builder(IndexRegistry indexRegistry, ANNOptions annOptions)
{
this.indexRegistry = indexRegistry;
this.annOptions = annOptions;
}

public RowFilter build()
{
return new RowFilter(current.build());
return new RowFilter(current.build(), annOptions);
}

public RowFilter buildFromRestrictions(StatementRestrictions restrictions, TableMetadata table, QueryOptions options, QueryState queryState)
public RowFilter buildFromRestrictions(StatementRestrictions restrictions,
TableMetadata table,
QueryOptions options,
QueryState queryState)
{
FilterElement root = doBuild(restrictions, table, options);

if (Guardrails.queryFilters.enabled(queryState))
Guardrails.queryFilters.guard(root.numFilteredValues(), "Select query", false, queryState);

return new RowFilter(root);
return new RowFilter(root, annOptions);
}

private FilterElement doBuild(StatementRestrictions restrictions, TableMetadata table, QueryOptions options)
Expand Down Expand Up @@ -386,7 +408,7 @@ public void addAllAsConjunction(Consumer<Builder> addToRowFilterDelegate)
{
// If we're in disjunction mode, we must not pass the current builder to addToRowFilter.
// We create a new conjunction sub-builder instead and add all expressions there.
var builder = new Builder(indexRegistry);
var builder = new Builder(indexRegistry, annOptions);
addToRowFilterDelegate.accept(builder);

if (builder.current.expressions.size() == 1 && builder.current.children.isEmpty())
Expand Down Expand Up @@ -1725,19 +1747,22 @@ public void serialize(RowFilter filter, DataOutputPlus out, int version) throws
{
out.writeBoolean(false); // Old "is for thrift" boolean
FilterElement.serializer.serialize(filter.root, out, version);
ANNOptions.serializer.serialize(filter.annOptions, out, version);
}

public RowFilter deserialize(DataInputPlus in, int version, TableMetadata metadata) throws IOException
{
in.readBoolean(); // Unused
FilterElement operation = FilterElement.serializer.deserialize(in, version, metadata);
return new RowFilter(operation);
ANNOptions annOptions = ANNOptions.serializer.deserialize(in, version);
return new RowFilter(operation, annOptions);
}

public long serializedSize(RowFilter filter, int version)
{
return 1 // unused boolean
+ FilterElement.serializer.serializedSize(filter.root, version);
+ FilterElement.serializer.serializedSize(filter.root, version)
+ ANNOptions.serializer.serializedSize(filter.annOptions, version);
}
}
}
Loading

0 comments on commit 42a25dc

Please sign in to comment.