Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add QueuedStatementResource timeout for query submission #62

Open
wants to merge 2 commits into
base: hotfix-350
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@
package io.prestosql.dispatcher;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Ordering;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.FluentFuture;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import io.prestosql.client.QueryError;
import io.prestosql.client.QueryResults;
import io.prestosql.client.StatementStats;
import io.prestosql.execution.ExecutionFailureInfo;
import io.prestosql.execution.QueryManagerConfig;
import io.prestosql.execution.QueryState;
import io.prestosql.server.HttpRequestSessionContext;
import io.prestosql.server.ServerConfig;
Expand All @@ -35,8 +36,10 @@
import io.prestosql.spi.security.GroupProvider;
import io.prestosql.spi.security.Identity;

import javax.annotation.Nullable;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;
import javax.inject.Inject;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.DELETE;
Expand All @@ -58,19 +61,21 @@
import javax.ws.rs.core.UriInfo;

import java.net.URI;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

import static com.clearspring.analytics.util.Preconditions.checkState;
import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Strings.isNullOrEmpty;
import static com.google.common.util.concurrent.Futures.nonCancellationPropagating;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static io.airlift.concurrent.MoreFutures.addTimeout;
import static io.airlift.concurrent.Threads.threadsNamed;
import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static io.airlift.jaxrs.AsyncResponseHandler.bindAsyncResponse;
import static io.prestosql.execution.QueryState.FAILED;
import static io.prestosql.execution.QueryState.QUEUED;
Expand Down Expand Up @@ -103,16 +108,16 @@ public class QueuedStatementResource
private final Executor responseExecutor;
private final ScheduledExecutorService timeoutExecutor;

private final ConcurrentMap<QueryId, Query> queries = new ConcurrentHashMap<>();
private final ScheduledExecutorService queryPurger = newSingleThreadScheduledExecutor(threadsNamed("dispatch-query-purger"));
private final boolean compressionEnabled;
private final QueryManager queryManager;

@Inject
public QueuedStatementResource(
GroupProvider groupProvider,
DispatchManager dispatchManager,
DispatchExecutor executor,
ServerConfig serverConfig)
ServerConfig serverConfig,
QueryManagerConfig queryManagerConfig)
{
this.groupProvider = requireNonNull(groupProvider, "groupProvider is null");
this.dispatchManager = requireNonNull(dispatchManager, "dispatchManager is null");
Expand All @@ -122,43 +127,20 @@ public QueuedStatementResource(
this.timeoutExecutor = requireNonNull(executor, "timeoutExecutor is null").getScheduledExecutor();
this.compressionEnabled = requireNonNull(serverConfig, "serverConfig is null").isQueryResultsCompressionEnabled();

queryPurger.scheduleWithFixedDelay(
() -> {
try {
// snapshot the queries before checking states to avoid registration race
for (Entry<QueryId, Query> entry : ImmutableSet.copyOf(queries.entrySet())) {
if (!entry.getValue().isSubmissionFinished()) {
continue;
}

// forget about this query if the query manager is no longer tracking it
if (!dispatchManager.isQueryRegistered(entry.getKey())) {
Query query = queries.remove(entry.getKey());
if (query != null) {
try {
query.destroy();
}
catch (Throwable e) {
// this catch clause is broad so query purger does not get stuck
log.warn(e, "Error destroying identity");
}
}
}
}
}
catch (Throwable e) {
log.warn(e, "Error removing old queries");
}
},
200,
200,
MILLISECONDS);
requireNonNull(queryManagerConfig, "queryManagerConfig is null");
queryManager = new QueryManager(queryManagerConfig.getClientTimeout());
}

@PostConstruct
public void start()
{
queryManager.initialize(dispatchManager);
}

@PreDestroy
public void stop()
{
queryPurger.shutdownNow();
queryManager.destroy();
}

@ResourceSecurity(AUTHENTICATED_USER)
Expand All @@ -174,18 +156,25 @@ public Response postStatement(
throw badRequest(BAD_REQUEST, "SQL statement is empty");
}

Query query = registerQuery(statement, servletRequest, httpHeaders);

return createQueryResultsResponse(query.getQueryResults(query.getLastToken(), uriInfo));
}

private Query registerQuery(String statement, HttpServletRequest servletRequest, HttpHeaders httpHeaders)
{
String remoteAddress = servletRequest.getRemoteAddr();
Optional<Identity> identity = Optional.ofNullable((Identity) servletRequest.getAttribute(AUTHENTICATED_IDENTITY));
MultivaluedMap<String, String> headers = httpHeaders.getRequestHeaders();

SessionContext sessionContext = new HttpRequestSessionContext(headers, remoteAddress, identity, groupProvider);
Query query = new Query(statement, sessionContext, dispatchManager);
queries.put(query.getQueryId(), query);
queryManager.registerQuery(query);

// let authentication filter know that identity lifecycle has been handed off
servletRequest.setAttribute(AUTHENTICATED_IDENTITY, null);

return createQueryResultsResponse(query.getQueryResults(query.getLastToken(), uriInfo), compressionEnabled);
return query;
}

@ResourceSecurity(PUBLIC)
Expand All @@ -202,25 +191,21 @@ public void getStatus(
{
Query query = getQuery(queryId, slug, token);

// wait for query to be dispatched, up to the wait timeout
ListenableFuture<?> futureStateChange = addTimeout(
query.waitForDispatched(),
() -> null,
WAIT_ORDERING.min(MAX_WAIT_TIME, maxWait),
timeoutExecutor);

// when state changes, fetch the next result
ListenableFuture<QueryResults> queryResultsFuture = Futures.transform(
futureStateChange,
ignored -> query.getQueryResults(token, uriInfo),
responseExecutor);

// transform to Response
ListenableFuture<Response> response = Futures.transform(
queryResultsFuture,
queryResults -> createQueryResultsResponse(queryResults, compressionEnabled),
directExecutor());
bindAsyncResponse(asyncResponse, response, responseExecutor);
ListenableFuture<Response> future = getStatus(query, token, maxWait, uriInfo);
bindAsyncResponse(asyncResponse, future, responseExecutor);
}

private ListenableFuture<Response> getStatus(Query query, long token, Duration maxWait, UriInfo uriInfo)
{
long waitMillis = WAIT_ORDERING.min(MAX_WAIT_TIME, maxWait).toMillis();

return FluentFuture.from(query.waitForDispatched())
// wait for query to be dispatched, up to the wait timeout
.withTimeout(waitMillis, MILLISECONDS, timeoutExecutor)
.catching(TimeoutException.class, ignored -> null, directExecutor())
// when state changes, fetch the next result
.transform(ignored -> query.getQueryResults(token, uriInfo), responseExecutor)
.transform(this::createQueryResultsResponse, directExecutor());
}

@ResourceSecurity(PUBLIC)
Expand All @@ -239,14 +224,14 @@ public Response cancelQuery(

private Query getQuery(QueryId queryId, String slug, long token)
{
Query query = queries.get(queryId);
Query query = queryManager.getQuery(queryId);
if (query == null || !query.getSlug().isValid(QUEUED_QUERY, slug, token)) {
throw badRequest(NOT_FOUND, "Query not found");
}
return query;
}

private static Response createQueryResultsResponse(QueryResults results, boolean compressionEnabled)
private Response createQueryResultsResponse(QueryResults results)
{
Response.ResponseBuilder builder = Response.ok(results);
if (!compressionEnabled) {
Expand Down Expand Up @@ -320,8 +305,9 @@ private static final class Query
private final Slug slug = Slug.createNew();
private final AtomicLong lastToken = new AtomicLong();

@GuardedBy("this")
private ListenableFuture<?> querySubmissionFuture;
private final long initTime = System.nanoTime();
private final AtomicReference<Boolean> submissionGate = new AtomicReference<>();
private final SettableFuture<Object> creationFuture = SettableFuture.create();

public Query(String query, SessionContext sessionContext, DispatchManager dispatchManager)
{
Expand All @@ -346,27 +332,38 @@ public long getLastToken()
return lastToken.get();
}

public synchronized boolean isSubmissionFinished()
public boolean tryAbandonSubmissionWithTimeout(Duration querySubmissionTimeout)
{
return Duration.nanosSince(initTime).compareTo(querySubmissionTimeout) >= 0 && submissionGate.compareAndSet(null, false);
}

public boolean isSubmissionAbandoned()
{
return Boolean.FALSE.equals(submissionGate.get());
}

public boolean isCreated()
{
return querySubmissionFuture != null && querySubmissionFuture.isDone();
return creationFuture.isDone();
}

private ListenableFuture<?> waitForDispatched()
{
// if query query submission has not finished, wait for it to finish
synchronized (this) {
if (querySubmissionFuture == null) {
querySubmissionFuture = dispatchManager.createQuery(queryId, slug, sessionContext, query);
}
if (!querySubmissionFuture.isDone()) {
return querySubmissionFuture;
}
submitIfNeeded();
if (!creationFuture.isDone()) {
return nonCancellationPropagating(creationFuture);
}

// otherwise, wait for the query to finish
return dispatchManager.waitForDispatched(queryId);
}

private void submitIfNeeded()
{
if (submissionGate.compareAndSet(null, true)) {
creationFuture.setFuture(dispatchManager.createQuery(queryId, slug, sessionContext, query));
}
}

public QueryResults getQueryResults(long token, UriInfo uriInfo)
{
long lastToken = this.lastToken.get();
Expand All @@ -377,14 +374,12 @@ public QueryResults getQueryResults(long token, UriInfo uriInfo)
// advance (or stay at) the token
this.lastToken.compareAndSet(lastToken, token);

synchronized (this) {
// if query submission has not finished, return simple empty result
if (querySubmissionFuture == null || !querySubmissionFuture.isDone()) {
return createQueryResults(
token + 1,
uriInfo,
DispatchInfo.queued(NO_DURATION, NO_DURATION));
}
// if query submission has not finished, return simple empty result
if (!creationFuture.isDone()) {
return createQueryResults(
token + 1,
uriInfo,
DispatchInfo.queued(NO_DURATION, NO_DURATION));
}

Optional<DispatchInfo> dispatchInfo = dispatchManager.getDispatchInfo(queryId);
Expand All @@ -398,9 +393,9 @@ public QueryResults getQueryResults(long token, UriInfo uriInfo)
return createQueryResults(token + 1, uriInfo, dispatchInfo.get());
}

public synchronized void cancel()
public void cancel()
{
querySubmissionFuture.addListener(() -> dispatchManager.cancelQuery(queryId), directExecutor());
creationFuture.addListener(() -> dispatchManager.cancelQuery(queryId), directExecutor());
}

public void destroy()
Expand Down Expand Up @@ -468,4 +463,82 @@ private QueryError toQueryError(ExecutionFailureInfo executionFailureInfo)
executionFailureInfo.toFailureInfo());
}
}

@ThreadSafe
private static class QueryManager
{
private final ConcurrentMap<QueryId, Query> queries = new ConcurrentHashMap<>();
private final ScheduledExecutorService scheduledExecutorService = newSingleThreadScheduledExecutor(daemonThreadsNamed("drain-state-query-manager"));

private final Duration querySubmissionTimeout;

public QueryManager(Duration querySubmissionTimeout)
{
this.querySubmissionTimeout = requireNonNull(querySubmissionTimeout, "querySubmissionTimeout is null");
}

public void initialize(DispatchManager dispatchManager)
{
scheduledExecutorService.scheduleWithFixedDelay(() -> syncWith(dispatchManager), 200, 200, MILLISECONDS);
}

public void destroy()
{
scheduledExecutorService.shutdownNow();
}

private void syncWith(DispatchManager dispatchManager)
{
queries.forEach((queryId, query) -> {
if (shouldBePurged(dispatchManager, query)) {
removeQuery(queryId);
}
});
}

private boolean shouldBePurged(DispatchManager dispatchManager, Query query)
{
if (query.isSubmissionAbandoned()) {
// Query submission was explicitly abandoned
return true;
}
if (query.tryAbandonSubmissionWithTimeout(querySubmissionTimeout)) {
// Query took too long to be submitted by the client
return true;
}
if (query.isCreated() && !dispatchManager.isQueryRegistered(query.getQueryId())) {
// Query was created in the DispatchManager, and DispatchManager has already purged the query
return true;
}
return false;
}

private void removeQuery(QueryId queryId)
{
Optional.ofNullable(queries.remove(queryId))
.ifPresent(QueryManager::destroyQuietly);
}

private static void destroyQuietly(Query query)
{
try {
query.destroy();
}
catch (Throwable t) {
log.error(t, "Error destroying query");
}
}

public void registerQuery(Query query)
{
Query existingQuery = queries.putIfAbsent(query.getQueryId(), query);
checkState(existingQuery == null, "Query already registered");
}

@Nullable
public Query getQuery(QueryId queryId)
{
return queries.get(queryId);
}
}
}