diff --git a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITQueryCountTest.java b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITQueryCountTest.java index 2b22c5562..5d052d28a 100644 --- a/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITQueryCountTest.java +++ b/google-cloud-firestore/src/test/java/com/google/cloud/firestore/it/ITQueryCountTest.java @@ -20,7 +20,6 @@ import static com.google.common.truth.Truth.assertThat; import static java.util.Collections.singletonMap; import static org.junit.Assert.assertThrows; -import static org.junit.Assume.assumeTrue; import com.google.api.core.ApiFuture; import com.google.auto.value.AutoValue; @@ -36,8 +35,8 @@ import com.google.cloud.firestore.QueryDocumentSnapshot; import com.google.cloud.firestore.TransactionOptions; import com.google.cloud.firestore.WriteBatch; -import com.google.cloud.firestore.WriteResult; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; import java.util.ArrayList; import java.util.List; import java.util.concurrent.CountDownLatch; @@ -45,6 +44,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -235,44 +235,41 @@ public void aggregateQueryShouldWorkInATransaction() throws Exception { @Test public void aggregateQueryInATransactionShouldLockTheCountedDocuments() throws Exception { - assumeTrue( - "Skip this test when running against production because " - + "it appears that production is failing to lock the counted documents b/248152832", - isRunningAgainstFirestoreEmulator()); + CollectionReference collection = createCollectionWithDocuments(7).collection(); + DocumentReference resultsDocument = createEmptyCollection().document(); - CollectionReference collection = createEmptyCollection(); - DocumentReference document = createDocumentInCollection(collection); CountDownLatch aggregateQueryExecutedSignal = new CountDownLatch(1); CountDownLatch transactionContinueSignal = new CountDownLatch(1); + AtomicInteger transactionInvokeCount = new AtomicInteger(0); ApiFuture transactionFuture = collection .getFirestore() .runTransaction( t -> { - t.get(collection.count()).get(); - aggregateQueryExecutedSignal.countDown(); - transactionContinueSignal.await(); + int invokeCount = transactionInvokeCount.getAndIncrement(); + long count = t.get(collection.count()).get().getCount(); + if (invokeCount == 0) { + aggregateQueryExecutedSignal.countDown(); + transactionContinueSignal.await(); + } + t.set(resultsDocument, ImmutableMap.of("count", count)); return null; }); - ExecutionException executionException; try { aggregateQueryExecutedSignal.await(); - ApiFuture documentSetTask = document.set(singletonMap("abc", "def")); - executionException = assertThrows(ExecutionException.class, documentSetTask::get); + // Add a document to the collection so the count retrieved in the transaction is stale. + collection.document().set(ImmutableMap.of("key", 42L)).get(); } finally { transactionContinueSignal.countDown(); } - assertThat(executionException) - .hasCauseThat() - .hasMessageThat() - .ignoringCase() - .contains("transaction lock timeout"); - // Wait for the transaction to finish. transactionFuture.get(); + + // Verify that the correct count was written in the transaction. + assertThat(resultsDocument.get().get().getLong("count")).isEqualTo(8); } @Test