[server][common] Fixed bug in AA/WC parallel processing support (#1252)
* [server][common] Fixed bug in AA/WC parallel processing support

This PR fixes the following issues:
1. AASIT (`ActiveActiveStoreIngestionTask`) now passes a non-null `KeyLevelLocksManager` to `IngestionBatchProcessor`; passing `null` previously allowed a race condition.
2. Fixed the locking order in `IngestionBatchProcessor` to avoid deadlocks (a sketch of the idea follows below).
3. Updated `SparseConcurrentList#computeIfAbsent` to skip adjusting the list size when the computed result is `null`.
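
For context, here is a minimal, hedged sketch of the lock-ordering idea behind fix 2: acquire key-level locks in ascending key order via a sorted map and release them in reverse. The `String` keys and the simple `ConcurrentHashMap` lock pool are illustrative stand-ins, not the real `ByteArrayKey` / `KeyLevelLocksManager` APIs.

import java.util.List;
import java.util.Map;
import java.util.NavigableMap;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;

public class OrderedKeyLockingSketch {
  // Illustrative lock pool; the real code obtains locks from KeyLevelLocksManager keyed by ByteArrayKey.
  private final Map<String, ReentrantLock> lockPool = new ConcurrentHashMap<>();

  /** Acquire locks in ascending key order so concurrent batches can never deadlock on each other. */
  public NavigableMap<String, ReentrantLock> lockKeys(List<String> keys) {
    NavigableMap<String, ReentrantLock> keyLockMap = new TreeMap<>();
    for (String key: keys) {
      // The TreeMap both deduplicates keys within the batch and imposes a deterministic order.
      keyLockMap.computeIfAbsent(key, k -> lockPool.computeIfAbsent(k, ignored -> new ReentrantLock()));
    }
    keyLockMap.forEach((key, lock) -> lock.lock());
    return keyLockMap;
  }

  /** Release in descending (reverse-acquisition) order, mirroring the new unlockKeys signature. */
  public void unlockKeys(NavigableMap<String, ReentrantLock> keyLockMap) {
    keyLockMap.descendingMap().forEach((key, lock) -> lock.unlock());
  }
}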
gaojieliu authored Oct 30, 2024
1 parent db45f57 commit d0ee623
Showing 10 changed files with 105 additions and 47 deletions.
ActiveActiveStoreIngestionTask.java
@@ -153,14 +153,14 @@ public ActiveActiveStoreIngestionTask(
this.remoteIngestionRepairService = builder.getRemoteIngestionRepairService();
this.ingestionBatchProcessorLazy = Lazy.of(() -> {
if (!serverConfig.isAAWCWorkloadParallelProcessingEnabled()) {
LOGGER.info("AA/WC workload parallel processing enabled is false");
LOGGER.info("AA/WC workload parallel processing is disabled for store version: {}", getKafkaVersionTopic());
return null;
}
LOGGER.info("AA/WC workload parallel processing enabled is true");
LOGGER.info("AA/WC workload parallel processing is enabled for store version: {}", getKafkaVersionTopic());
return new IngestionBatchProcessor(
kafkaVersionTopic,
parallelProcessingThreadPool,
null,
keyLevelLocksManager.get(),
this::processActiveActiveMessage,
isWriteComputationEnabled,
isActiveActiveReplicationEnabled(),
IngestionBatchProcessor.java
@@ -14,9 +14,10 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NavigableMap;
import java.util.TreeMap;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantLock;


@@ -27,6 +28,8 @@
* resources to speed up the leader ingestion.
*/
public class IngestionBatchProcessor {
private static final TreeMap EMPTY_TREE_MAP = new TreeMap();

interface ProcessingFunction {
PubSubMessageProcessedResult apply(
PubSubMessage<KafkaKey, KafkaMessageEnvelope, Long> consumerRecord,
@@ -71,32 +74,41 @@ public IngestionBatchProcessor(
this.version = Version.parseVersionFromKafkaTopicName(storeVersionName);
}

// For testing
KeyLevelLocksManager getLockManager() {
return this.lockManager;
}

/**
* When {@link #lockManager} is not null, this function will try to lock all the keys
* (except Control Messages) passed by the params.
*/
public List<ReentrantLock> lockKeys(List<PubSubMessage<KafkaKey, KafkaMessageEnvelope, Long>> records) {
public NavigableMap<ByteArrayKey, ReentrantLock> lockKeys(
List<PubSubMessage<KafkaKey, KafkaMessageEnvelope, Long>> records) {
if (lockManager != null) {
List<ReentrantLock> locks = new ArrayList<>(records.size());
/**
* Need to use a {@link TreeMap} to make sure the locking will be executed in a deterministic order, otherwise
* deadlock can happen.
* Considering there could be multiple consumers, which are executing this function concurrently, and if they
* are trying to lock the same set of keys with different orders, deadlock can happen.
*/
TreeMap<ByteArrayKey, ReentrantLock> keyLockMap = new TreeMap<>();
records.forEach(r -> {
if (!r.getKey().isControlMessage()) {
ReentrantLock lock = lockManager.acquireLockByKey(ByteArrayKey.wrap(r.getKey().getKey()));
locks.add(lock);
lock.lock();
keyLockMap.computeIfAbsent(ByteArrayKey.wrap(r.getKey().getKey()), k -> lockManager.acquireLockByKey(k));
}
});
return locks;
keyLockMap.forEach((k, v) -> v.lock());
return keyLockMap;
}
return Collections.emptyList();
return Collections.emptyNavigableMap();
}

public void unlockKeys(List<PubSubMessage<KafkaKey, KafkaMessageEnvelope, Long>> records, List<ReentrantLock> locks) {
public void unlockKeys(NavigableMap<ByteArrayKey, ReentrantLock> keyLockMap) {
if (lockManager != null) {
locks.forEach(lock -> lock.unlock());
records.forEach(r -> {
if (!r.getKey().isControlMessage()) {
lockManager.releaseLock(ByteArrayKey.wrap(r.getKey().getKey()));
}
keyLockMap.descendingMap().forEach((key, lock) -> {
lock.unlock();
lockManager.releaseLock(key);
});
}
}
@@ -123,38 +135,35 @@ public List<PubSubMessageProcessedResultWrapper<KafkaKey, KafkaMessageEnvelope,
if (records.isEmpty()) {
return Collections.emptyList();
}
AtomicBoolean isAllMessagesFromRTTopic = new AtomicBoolean(true);
boolean isAllMessagesFromRTTopic = true;
List<PubSubMessageProcessedResultWrapper<KafkaKey, KafkaMessageEnvelope, Long>> resultList =
new ArrayList<>(records.size());
records.forEach(r -> {
resultList.add(new PubSubMessageProcessedResultWrapper<>(r));
if (!r.getTopicPartition().getPubSubTopic().isRealTime()) {
isAllMessagesFromRTTopic.set(false);
}
});
if (!isWriteComputationEnabled && !isActiveActiveReplicationEnabled) {
return resultList;
}
// Only handle records from the real-time topic
if (!isAllMessagesFromRTTopic.get()) {
return resultList;
}

/**
* We would like to process the messages belonging to the same key sequentially to avoid race conditions.
*/
int totalNumOfRecords = 0;
Map<ByteArrayKey, List<PubSubMessageProcessedResultWrapper<KafkaKey, KafkaMessageEnvelope, Long>>> keyGroupMap =
new HashMap<>(records.size());

for (PubSubMessageProcessedResultWrapper<KafkaKey, KafkaMessageEnvelope, Long> r: resultList) {
PubSubMessage<KafkaKey, KafkaMessageEnvelope, Long> message = r.getMessage();
if (!message.getKey().isControlMessage()) {
for (PubSubMessage<KafkaKey, KafkaMessageEnvelope, Long> message: records) {
if (!message.getTopicPartition().getPubSubTopic().isRealTime()) {
isAllMessagesFromRTTopic = false;
}
PubSubMessageProcessedResultWrapper resultWrapper = new PubSubMessageProcessedResultWrapper<>(message);
resultList.add(resultWrapper);
if (!message.getKey().isControlMessage() && isAllMessagesFromRTTopic) {
ByteArrayKey byteArrayKey = ByteArrayKey.wrap(message.getKey().getKey());
keyGroupMap.computeIfAbsent(byteArrayKey, (ignored) -> new ArrayList<>()).add(r);
keyGroupMap.computeIfAbsent(byteArrayKey, (ignored) -> new ArrayList<>()).add(resultWrapper);
totalNumOfRecords++;
}
}
if (!isWriteComputationEnabled && !isActiveActiveReplicationEnabled) {
return resultList;
}
// Only handle records from the real-time topic
if (!isAllMessagesFromRTTopic) {
return resultList;
}
aggVersionedIngestionStats
.recordBatchProcessingRequest(storeName, version, totalNumOfRecords, System.currentTimeMillis());
hostLevelIngestionStats.recordBatchProcessingRequest(totalNumOfRecords);
LeaderFollowerStoreIngestionTask.java
@@ -340,10 +340,10 @@ public LeaderFollowerStoreIngestionTask(
serverConfig.isComputeFastAvroEnabled());
this.ingestionBatchProcessingLazy = Lazy.of(() -> {
if (!serverConfig.isAAWCWorkloadParallelProcessingEnabled()) {
LOGGER.info("AA/WC workload parallel processing enabled is false");
LOGGER.info("AA/WC workload parallel processing is disabled for store version: {}", getKafkaVersionTopic());
return null;
}
LOGGER.info("AA/WC workload parallel processing enabled is true");
LOGGER.info("AA/WC workload parallel processing is enabled for store version: {}", getKafkaVersionTopic());
return new IngestionBatchProcessor(
kafkaVersionTopic,
parallelProcessingThreadPool,
PubSubMessageProcessedResultWrapper.java
@@ -19,7 +19,7 @@ public PubSubMessageProcessedResult getProcessedResult() {
return processedResult;
}

public void setProcessedResult(PubSubMessageProcessedResult transformedResult) {
this.processedResult = transformedResult;
public void setProcessedResult(PubSubMessageProcessedResult processedResult) {
this.processedResult = processedResult;
}
}
StoreIngestionTask.java
@@ -38,6 +38,7 @@
import com.linkedin.davinci.store.StoragePartitionConfig;
import com.linkedin.davinci.store.cache.backend.ObjectCacheBackend;
import com.linkedin.davinci.store.record.ValueRecord;
import com.linkedin.davinci.utils.ByteArrayKey;
import com.linkedin.davinci.utils.ChunkAssembler;
import com.linkedin.davinci.validation.KafkaDataIntegrityValidator;
import com.linkedin.davinci.validation.PartitionTracker;
@@ -125,6 +126,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NavigableMap;
import java.util.Objects;
import java.util.Optional;
import java.util.Properties;
@@ -1262,7 +1264,7 @@ protected void produceToStoreBufferServiceOrKafkaInBatch(
* Process records batch by batch.
*/
for (List<PubSubMessage<KafkaKey, KafkaMessageEnvelope, Long>> batch: batches) {
List<ReentrantLock> locks = ingestionBatchProcessor.lockKeys(batch);
NavigableMap<ByteArrayKey, ReentrantLock> keyLockMap = ingestionBatchProcessor.lockKeys(batch);
try {
long beforeProcessingPerRecordTimestampNs = System.nanoTime();
List<PubSubMessageProcessedResultWrapper<KafkaKey, KafkaMessageEnvelope, Long>> processedResults =
@@ -1288,7 +1290,7 @@ protected void produceToStoreBufferServiceOrKafkaInBatch(
elapsedTimeForPuttingIntoQueue);
}
} finally {
ingestionBatchProcessor.unlockKeys(batch, locks);
ingestionBatchProcessor.unlockKeys(keyLockMap);
}
}

ByteArrayKey.java
@@ -1,12 +1,13 @@
package com.linkedin.davinci.utils;

import com.linkedin.venice.utils.ByteUtils;
import java.util.Arrays;


/**
* A low overhead immutable container of byte[] suitable for use as a map key.
*/
public class ByteArrayKey {
public class ByteArrayKey implements Comparable<ByteArrayKey> {
private final byte[] content;
private final int hashCode;

@@ -31,6 +32,10 @@ public boolean equals(Object o) {
return Arrays.equals(content, that.content);
}

public byte[] getContent() {
return this.content;
}

@Override
public int hashCode() {
return this.hashCode;
@@ -39,4 +44,9 @@ public int hashCode() {
public static ByteArrayKey wrap(byte[] content) {
return new ByteArrayKey(content);
}

@Override
public int compareTo(ByteArrayKey o) {
return ByteUtils.compare(content, o.content);
}
}
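
Making `ByteArrayKey` comparable is what allows `lockKeys` to sort keys in a `TreeMap`. Below is a small illustrative example of the resulting ordering, assuming `ByteUtils.compare` performs a lexicographic byte comparison (which the deterministic-ordering argument relies on).

import com.linkedin.davinci.utils.ByteArrayKey;
import java.util.Arrays;
import java.util.TreeMap;
import java.util.concurrent.locks.ReentrantLock;

public class ByteArrayKeyOrderingExample {
  public static void main(String[] args) {
    // With Comparable implemented, every batch iterates (and therefore locks) keys in the same order.
    TreeMap<ByteArrayKey, ReentrantLock> keyLockMap = new TreeMap<>();
    keyLockMap.put(ByteArrayKey.wrap(new byte[] { 2, 1 }), new ReentrantLock());
    keyLockMap.put(ByteArrayKey.wrap(new byte[] { 1, 9 }), new ReentrantLock());
    // Assuming lexicographic comparison, iteration always starts with {1, 9}.
    keyLockMap.keySet().forEach(key -> System.out.println(Arrays.toString(key.getContent())));
  }
}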
IngestionBatchProcessorTest.java
@@ -4,6 +4,7 @@
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.testng.Assert.assertEquals;
@@ -31,6 +32,7 @@
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.NavigableMap;
import java.util.concurrent.locks.ReentrantLock;
import org.testng.annotations.Test;

@@ -116,21 +118,38 @@ public void lockKeysTest() {
true,
mock(AggVersionedIngestionStats.class),
mock(HostLevelIngestionStats.class));
List<ReentrantLock> locks = batchProcessor.lockKeys(Arrays.asList(rtMessage1, rtMessage2));
/**
* Switch the input order to make sure the `lockKeys` function would sort them when locking.
*/
NavigableMap<ByteArrayKey, ReentrantLock> keyLockMap =
batchProcessor.lockKeys(Arrays.asList(rtMessage2, rtMessage1));
verify(mockKeyLevelLocksManager).acquireLockByKey(ByteArrayKey.wrap(key1));
verify(mockKeyLevelLocksManager).acquireLockByKey(ByteArrayKey.wrap(key2));
verify(lockForKey1).lock();
verify(lockForKey2).lock();
assertEquals(locks.get(0), lockForKey1);
assertEquals(locks.get(1), lockForKey2);
// Verify the order
ReentrantLock[] locks = keyLockMap.values().toArray(new ReentrantLock[0]);
assertEquals(locks[0], lockForKey1);
assertEquals(locks[1], lockForKey2);

// unlock test
batchProcessor.unlockKeys(Arrays.asList(rtMessage1, rtMessage2), locks);
batchProcessor.unlockKeys(keyLockMap);

verify(lockForKey1).unlock();
verify(lockForKey2).unlock();
verify(mockKeyLevelLocksManager).releaseLock(ByteArrayKey.wrap(key1));
verify(mockKeyLevelLocksManager).releaseLock(ByteArrayKey.wrap(key2));

// Duplicate messages in the batch
keyLockMap = batchProcessor.lockKeys(Arrays.asList(rtMessage1, rtMessage2, rtMessage1));
verify(mockKeyLevelLocksManager, times(2)).acquireLockByKey(ByteArrayKey.wrap(key1));
verify(mockKeyLevelLocksManager, times(2)).acquireLockByKey(ByteArrayKey.wrap(key2));
verify(lockForKey1, times(2)).lock();
verify(lockForKey2, times(2)).lock();
// Verify the order
locks = keyLockMap.values().toArray(new ReentrantLock[0]);
assertEquals(locks[0], lockForKey1);
assertEquals(locks[1], lockForKey2);
}

@Test
StoreIngestionTaskTest.java
@@ -3152,6 +3152,14 @@ public void testActiveActiveStoreIsReadyToServe(HybridConfig hybridConfig, NodeT
Optional.empty(),
null);

if (hybridConfig.equals(HYBRID) && nodeType.equals(LEADER) && isAaWCParallelProcessingEnabled()) {
assertTrue(storeIngestionTaskUnderTest instanceof ActiveActiveStoreIngestionTask);
ActiveActiveStoreIngestionTask activeActiveStoreIngestionTask =
(ActiveActiveStoreIngestionTask) storeIngestionTaskUnderTest;
assertNotNull(activeActiveStoreIngestionTask.getIngestionBatchProcessor());
assertNotNull(activeActiveStoreIngestionTask.getIngestionBatchProcessor().getLockManager());
}

String rtTopicName = Version.composeRealTimeTopic(mockStore.getName());
PubSubTopic rtTopic = pubSubTopicRepository.getTopic(rtTopicName);
TopicSwitch topicSwitchWithMultipleSourceKafkaServers = new TopicSwitch();
SparseConcurrentList.java
@@ -249,6 +249,12 @@ public E computeIfAbsent(int index, IntFunction<? extends E> mappingFunction) {
element = get(index);
if (element == null) {
element = mappingFunction.apply(index);
/**
* Don't update the list if the computed result is `null`.
*/
if (element == null) {
return null;
}
/**
* It's important NOT to call {@link #handleSizeDuringMutation(Object, Object)} since {@link #set(int, Object)}
* already calls it.
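
To make fix 3 concrete, here is a hedged usage sketch of the new `computeIfAbsent` behavior; the import path and no-arg constructor are assumptions based on the surrounding code, not verified against the repository.

import com.linkedin.venice.utils.SparseConcurrentList; // assumed location of the class

public class ComputeIfAbsentNullSketch {
  public static void main(String[] args) {
    SparseConcurrentList<String> list = new SparseConcurrentList<>(); // assumed no-arg constructor
    // The mapping function may legitimately return null (e.g. nothing to register for this index).
    String result = list.computeIfAbsent(40, index -> null);
    // With this fix the list is no longer padded out to index 40: null is returned
    // and the list size stays unchanged.
    System.out.println(result);      // null
    System.out.println(list.size()); // 0
  }
}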
SparseConcurrentListTest.java
@@ -86,6 +86,10 @@ public void testNonNullSize() {
assertEquals(scl.values().size(), scl.nonNullSize());
assertFalse(scl.isEmpty());

// Compute if absent for an unpopulated index with computed result as `null`.
scl.computeIfAbsent(40, k -> null);
assertEquals(scl.size(), 8);

// Go back to the initial state...
scl.clear();
assertEquals(scl.size(), 0);
