Skip to content

Commit

Permalink
[client-common] Added safeguard for compressor (#1307)
Browse files Browse the repository at this point in the history
* [client-common] Added safeguard for compressor

Today, `compress`/`decompress` can still be invoked
even when the compressor is already closed, and for the
zstd-based compressor this would crash.
This PR adds a safeguard so the compressor fails fast
if it is already closed.

* Fixed integration test failures

* Minor tweak

* Added a unit test

* Fixed minor comment

* Skipped locking for NoopCompressor
  • Loading branch information
gaojieliu authored Nov 15, 2024
1 parent d125933 commit 06390c8
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public GzipCompressor() {
}

@Override
public byte[] compress(byte[] data) throws IOException {
protected byte[] compressInternal(byte[] data) throws IOException {
ReusableGzipOutputStream out = gzipPool.getReusableGzipOutputStream();
try {
out.writeHeader();
Expand All @@ -37,7 +37,7 @@ public byte[] compress(byte[] data) throws IOException {
}

@Override
public void close() throws IOException {
protected void closeInternal() throws IOException {
try {
gzipPool.close();
} catch (Exception e) {
Expand All @@ -47,7 +47,7 @@ public void close() throws IOException {
}

@Override
public ByteBuffer compress(ByteBuffer data, int startPositionOfOutput) throws IOException {
protected ByteBuffer compressInternal(ByteBuffer data, int startPositionOfOutput) throws IOException {
/**
* N.B.: We initialize the size of buffer in this output stream at the size of the deflated payload, which is not
* ideal, but not necessarily bad either. The assumption is that GZIP usually doesn't compress our payloads that
Expand All @@ -74,7 +74,7 @@ public ByteBuffer compress(ByteBuffer data, int startPositionOfOutput) throws IO
}

@Override
public ByteBuffer decompress(ByteBuffer data) throws IOException {
protected ByteBuffer decompressInternal(ByteBuffer data) throws IOException {
if (data.hasRemaining()) {
if (data.hasArray()) {
return decompress(data.array(), data.position(), data.remaining());
Expand All @@ -89,14 +89,14 @@ public ByteBuffer decompress(ByteBuffer data) throws IOException {
}

@Override
public ByteBuffer decompress(byte[] data, int offset, int length) throws IOException {
protected ByteBuffer decompressInternal(byte[] data, int offset, int length) throws IOException {
try (InputStream gis = decompress(new ByteArrayInputStream(data, offset, length))) {
return ByteBuffer.wrap(IOUtils.toByteArray(gis));
}
}

@Override
public ByteBuffer decompressAndPrependSchemaHeader(byte[] data, int offset, int length, int schemaHeader)
protected ByteBuffer decompressAndPrependSchemaHeaderInternal(byte[] data, int offset, int length, int schemaHeader)
throws IOException {
byte[] decompressedByteArray;
try (InputStream gis = decompress(new ByteArrayInputStream(data, offset, length))) {
Expand All @@ -111,7 +111,7 @@ public ByteBuffer decompressAndPrependSchemaHeader(byte[] data, int offset, int
}

@Override
public InputStream decompress(InputStream inputStream) throws IOException {
protected InputStream decompressInternal(InputStream inputStream) throws IOException {
return new GZIPInputStream(inputStream);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import java.nio.ByteBuffer;


/**
* Locking is not necessary for {@link NoopCompressor}, so this class overrides all the public APIs to avoid locking.
*/
public class NoopCompressor extends VeniceCompressor {
public NoopCompressor() {
super(CompressionStrategy.NO_OP);
Expand All @@ -16,6 +19,11 @@ public byte[] compress(byte[] data) throws IOException {
return data;
}

@Override
protected byte[] compressInternal(byte[] data) throws IOException {
throw new UnsupportedOperationException("compressInternal");
}

@Override
public ByteBuffer compress(ByteBuffer data, int startPositionOfOutput) throws IOException {
if (startPositionOfOutput != 0) {
Expand All @@ -24,6 +32,11 @@ public ByteBuffer compress(ByteBuffer data, int startPositionOfOutput) throws IO
return data;
}

@Override
protected ByteBuffer compressInternal(ByteBuffer src, int startPositionOfOutput) throws IOException {
throw new UnsupportedOperationException("compressInternal");
}

@Override
public int hashCode() {
return super.hashCode();
Expand All @@ -34,11 +47,21 @@ public ByteBuffer decompress(ByteBuffer data) throws IOException {
return data;
}

@Override
protected ByteBuffer decompressInternal(ByteBuffer data) throws IOException {
throw new UnsupportedOperationException("decompressInternal");
}

@Override
public ByteBuffer decompress(byte[] data, int offset, int length) throws IOException {
return ByteBuffer.wrap(data, offset, length);
}

@Override
protected ByteBuffer decompressInternal(byte[] data, int offset, int length) throws IOException {
throw new UnsupportedOperationException("decompressInternal");
}

@Override
public ByteBuffer decompressAndPrependSchemaHeader(byte[] data, int offset, int length, int schemaHeader)
throws IOException {
Expand All @@ -50,11 +73,32 @@ public ByteBuffer decompressAndPrependSchemaHeader(byte[] data, int offset, int
return bb;
}

@Override
protected ByteBuffer decompressAndPrependSchemaHeaderInternal(byte[] data, int offset, int length, int schemaHeader)
throws IOException {
throw new UnsupportedOperationException("decompressAndPrependSchemaHeaderInternal");
}

@Override
public InputStream decompress(InputStream inputStream) throws IOException {
return inputStream;
}

@Override
protected InputStream decompressInternal(InputStream inputStream) throws IOException {
throw new UnsupportedOperationException("decompressInternal");
}

@Override
public void close() throws IOException {
// do nothing
}

@Override
protected void closeInternal() throws IOException {
throw new UnsupportedOperationException("closeInternal");
}

@Override
public boolean equals(Object o) {
if (o == this) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,43 +1,103 @@
package com.linkedin.venice.compression;

import com.linkedin.venice.exceptions.VeniceException;
import com.linkedin.venice.utils.ByteUtils;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.concurrent.locks.ReentrantReadWriteLock;


public abstract class VeniceCompressor implements Closeable {
protected static final int SCHEMA_HEADER_LENGTH = ByteUtils.SIZE_OF_INT;
private final CompressionStrategy compressionStrategy;
private boolean isClosed = false;
/**
* To avoid the race condition between 'compress'/'decompress' operation and 'close'.
*/
private final ReentrantReadWriteLock readWriteLock = new ReentrantReadWriteLock();

protected VeniceCompressor(CompressionStrategy compressionStrategy) {
this.compressionStrategy = compressionStrategy;
}

public abstract byte[] compress(byte[] data) throws IOException;
interface CompressionRunnable<R> {
R run() throws IOException;
}

private <R> R executeWithSafeGuard(CompressionRunnable<R> runnable) throws IOException {
readWriteLock.readLock().lock();
try {
if (isClosed) {
throw new VeniceException("Compressor for " + getCompressionStrategy() + " has been closed");
}
return runnable.run();
} finally {
readWriteLock.readLock().unlock();
}
}

public byte[] compress(byte[] data) throws IOException {
return executeWithSafeGuard(() -> compressInternal(data));
}

public abstract ByteBuffer compress(ByteBuffer src, int startPositionOfOutput) throws IOException;
protected abstract byte[] compressInternal(byte[] data) throws IOException;

public abstract ByteBuffer decompress(ByteBuffer data) throws IOException;
public ByteBuffer compress(ByteBuffer src, int startPositionOfOutput) throws IOException {
return executeWithSafeGuard(() -> compressInternal(src, startPositionOfOutput));
}

public abstract ByteBuffer decompress(byte[] data, int offset, int length) throws IOException;
protected abstract ByteBuffer compressInternal(ByteBuffer src, int startPositionOfOutput) throws IOException;

public ByteBuffer decompress(ByteBuffer data) throws IOException {
return executeWithSafeGuard(() -> decompressInternal(data));
}

protected abstract ByteBuffer decompressInternal(ByteBuffer data) throws IOException;

public ByteBuffer decompress(byte[] data, int offset, int length) throws IOException {
return executeWithSafeGuard(() -> decompressInternal(data, offset, length));
}

protected abstract ByteBuffer decompressInternal(byte[] data, int offset, int length) throws IOException;

/**
* This method tries to decompress data and maybe prepend the schema header.
* The returned ByteBuffer will be backed by byte array that starts with schema header, followed by the
* decompressed data. The ByteBuffer will be positioned at the beginning of the decompressed data and the remaining of
* the ByteBuffer will be the length of the decompressed data.
*/
public abstract ByteBuffer decompressAndPrependSchemaHeader(byte[] data, int offset, int length, int schemaHeader)
throws IOException;
public ByteBuffer decompressAndPrependSchemaHeader(byte[] data, int offset, int length, int schemaHeader)
throws IOException {
return executeWithSafeGuard(() -> decompressAndPrependSchemaHeaderInternal(data, offset, length, schemaHeader));
}

protected abstract ByteBuffer decompressAndPrependSchemaHeaderInternal(
byte[] data,
int offset,
int length,
int schemaHeader) throws IOException;

public CompressionStrategy getCompressionStrategy() {
return compressionStrategy;
}

public abstract InputStream decompress(InputStream inputStream) throws IOException;
public InputStream decompress(InputStream inputStream) throws IOException {
return executeWithSafeGuard(() -> decompressInternal(inputStream));
}

protected abstract InputStream decompressInternal(InputStream inputStream) throws IOException;

public void close() throws IOException {
readWriteLock.writeLock().lock();
try {
isClosed = true;
closeInternal();
} finally {
readWriteLock.writeLock().unlock();
}
}

protected abstract void closeInternal() throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ public ZstdWithDictCompressor(final byte[] dictionary, int level) {
}

@Override
public byte[] compress(byte[] data) {
protected byte[] compressInternal(byte[] data) {
return compressor.get().compress(data);
}

@Override
public ByteBuffer compress(ByteBuffer data, int startPositionOfOutput) throws IOException {
protected ByteBuffer compressInternal(ByteBuffer data, int startPositionOfOutput) throws IOException {
long maxDstSize = Zstd.compressBound(data.remaining());
if (maxDstSize + startPositionOfOutput > Integer.MAX_VALUE) {
throw new ZstdException(Zstd.errGeneric(), "Max output size is greater than Integer.MAX_VALUE");
Expand Down Expand Up @@ -87,7 +87,7 @@ public ByteBuffer compress(ByteBuffer data, int startPositionOfOutput) throws IO
}

@Override
public ByteBuffer decompress(ByteBuffer data) throws IOException {
protected ByteBuffer decompressInternal(ByteBuffer data) throws IOException {
if (data.hasRemaining()) {
if (data.hasArray()) {
return decompress(data.array(), data.position(), data.remaining());
Expand All @@ -107,7 +107,7 @@ public ByteBuffer decompress(ByteBuffer data) throws IOException {
}

@Override
public ByteBuffer decompress(byte[] data, int offset, int length) throws IOException {
protected ByteBuffer decompressInternal(byte[] data, int offset, int length) throws IOException {
int expectedSize = validateExpectedDecompressedSize(Zstd.decompressedSize(data, offset, length));
ByteBuffer returnedData = ByteBuffer.allocate(expectedSize);
int actualSize = decompressor.get()
Expand All @@ -124,7 +124,7 @@ public ByteBuffer decompress(byte[] data, int offset, int length) throws IOExcep
}

@Override
public ByteBuffer decompressAndPrependSchemaHeader(byte[] data, int offset, int length, int schemaHeader)
protected ByteBuffer decompressAndPrependSchemaHeaderInternal(byte[] data, int offset, int length, int schemaHeader)
throws IOException {
int expectedDecompressedDataSize = validateExpectedDecompressedSize(Zstd.decompressedSize(data, offset, length));

Expand All @@ -138,12 +138,12 @@ public ByteBuffer decompressAndPrependSchemaHeader(byte[] data, int offset, int
}

@Override
public InputStream decompress(InputStream inputStream) throws IOException {
protected InputStream decompressInternal(InputStream inputStream) throws IOException {
return new ZstdInputStream(inputStream).setDict(this.dictDecompress);
}

@Override
public void close() throws IOException {
protected void closeInternal() throws IOException {
this.compressor.close();
this.decompressor.close();
IOUtils.closeQuietly(this.dictCompress);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package com.linkedin.venice.compression;

import static org.testng.Assert.assertThrows;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.expectThrows;

import com.github.luben.zstd.Zstd;
import com.linkedin.venice.exceptions.VeniceException;
import com.linkedin.venice.utils.ByteUtils;
import com.linkedin.venice.utils.TestUtils;
import com.linkedin.venice.utils.Time;
Expand All @@ -14,6 +19,7 @@
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.testng.Assert;
Expand Down Expand Up @@ -173,7 +179,7 @@ private enum SourceDataType {

@Test
public void testZSTDThrowsExceptionOnNullDictionary() {
Assert.assertThrows(
assertThrows(
() -> new CompressorFactory()
.createVersionSpecificCompressorIfNotExist(CompressionStrategy.ZSTD_WITH_DICT, "foo_v1", null));
}
Expand Down Expand Up @@ -205,4 +211,15 @@ public void testCompressorEqual() {
}
}
}

@Test
public void testCompressorClose() throws IOException {
VeniceCompressor compressor = new ZstdWithDictCompressor("abc".getBytes(), Zstd.maxCompressionLevel());
String largePayload = RandomStringUtils.randomAlphabetic(500000);
compressor.compress(largePayload.getBytes());
compressor.close();
VeniceException exception =
expectThrows(VeniceException.class, () -> compressor.compress(ByteBuffer.wrap(largePayload.getBytes()), 4));
assertTrue(exception.getMessage().contains("has been closed"));
}
}

0 comments on commit 06390c8

Please sign in to comment.