From 7e9e0c5dd009ac440170a6d3c9e674bafdb0e4ff Mon Sep 17 00:00:00 2001 From: tore-statsig <74584483+tore-statsig@users.noreply.github.com> Date: Tue, 13 Aug 2024 17:02:48 -0700 Subject: [PATCH] test: multi initialize/initializeasync (#247) * test: multi initialize/initializeasync * fix: cleanup test, add long timeout case * fix: test * fix: lints * feat: make initialize and initializeAsync threadsafe * fix: unset in shutdown * fix: eb test * fix: lint * fix: async test code --- .../com/statsig/androidsdk/ErrorBoundary.kt | 5 +- .../com/statsig/androidsdk/StatsigClient.kt | 10 +- .../statsig/androidsdk/ErrorBoundaryTest.kt | 11 +- .../StatsigLongInitializationTimeoutTest.kt | 91 ++++++++++ .../StatsigMultipleInitializeTest.kt | 170 ++++++++++++++++++ 5 files changed, 272 insertions(+), 15 deletions(-) create mode 100644 src/test/java/com/statsig/androidsdk/StatsigLongInitializationTimeoutTest.kt create mode 100644 src/test/java/com/statsig/androidsdk/StatsigMultipleInitializeTest.kt diff --git a/src/main/java/com/statsig/androidsdk/ErrorBoundary.kt b/src/main/java/com/statsig/androidsdk/ErrorBoundary.kt index b7dc7b9..7f1eef7 100644 --- a/src/main/java/com/statsig/androidsdk/ErrorBoundary.kt +++ b/src/main/java/com/statsig/androidsdk/ErrorBoundary.kt @@ -1,10 +1,7 @@ package com.statsig.androidsdk import com.google.gson.Gson -import kotlinx.coroutines.CoroutineExceptionHandler -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.launch +import kotlinx.coroutines.* import java.io.DataOutputStream import java.lang.RuntimeException import java.net.HttpURLConnection diff --git a/src/main/java/com/statsig/androidsdk/StatsigClient.kt b/src/main/java/com/statsig/androidsdk/StatsigClient.kt index 2908854..37dccbe 100644 --- a/src/main/java/com/statsig/androidsdk/StatsigClient.kt +++ b/src/main/java/com/statsig/androidsdk/StatsigClient.kt @@ -44,6 +44,7 @@ class StatsigClient() : LifecycleEventListener { private var dispatcherProvider = CoroutineDispatcherProvider() private var initialized = AtomicBoolean(false) private var isBootstrapped = AtomicBoolean(false) + private var isInitializing = AtomicBoolean(false) @VisibleForTesting internal lateinit var statsigNetwork: StatsigNetwork @@ -71,7 +72,7 @@ class StatsigClient() : LifecycleEventListener { callback: IStatsigCallback? = null, options: StatsigOptions = StatsigOptions(), ) { - if (isInitialized()) { + if (isInitializing.getAndSet(true)) { return } errorBoundary.setKey(sdkKey) @@ -121,7 +122,7 @@ class StatsigClient() : LifecycleEventListener { user: StatsigUser? = null, options: StatsigOptions = StatsigOptions(), ): InitializationDetails? { - if (this@StatsigClient.isInitialized()) { + if (isInitializing.getAndSet(true)) { return null } errorBoundary.setKey(sdkKey) @@ -1034,13 +1035,14 @@ class StatsigClient() : LifecycleEventListener { } private suspend fun shutdownImpl() { + initialized.set(false) pollingJob?.cancel() logger.shutdown() lifecycleListener.shutdown() - initialized = AtomicBoolean() - isBootstrapped = AtomicBoolean() + isBootstrapped.set(false) errorBoundary = ErrorBoundary() statsigJob = SupervisorJob() + isInitializing.set(false) } private fun logEndDiagnostics(success: Boolean, context: ContextType, initResponse: InitializeResponse?) { diff --git a/src/test/java/com/statsig/androidsdk/ErrorBoundaryTest.kt b/src/test/java/com/statsig/androidsdk/ErrorBoundaryTest.kt index e2706de..70a20e5 100644 --- a/src/test/java/com/statsig/androidsdk/ErrorBoundaryTest.kt +++ b/src/test/java/com/statsig/androidsdk/ErrorBoundaryTest.kt @@ -6,6 +6,7 @@ import com.github.tomakehurst.wiremock.junit.WireMockRule import io.mockk.mockk import io.mockk.unmockkAll import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runBlockingTest import org.junit.After import org.junit.Assert.* import org.junit.Before @@ -19,6 +20,7 @@ class ErrorBoundaryTest { @Before internal fun setup() { + TestUtil.mockDispatchers() boundary = ErrorBoundary() boundary.setKey("client-key") boundary.urlString = wireMockRule.url("/v1/sdk_exception") @@ -26,7 +28,6 @@ class ErrorBoundaryTest { stubFor(post(urlMatching("/v1/sdk_exception")).willReturn(aResponse().withStatus(202))) app = mockk() - TestUtil.mockDispatchers() TestUtil.stubAppFunctions(app) val network = TestUtil.mockBrokenNetwork() Statsig.client = StatsigClient() @@ -44,7 +45,7 @@ class ErrorBoundaryTest { val wireMockRule = WireMockRule() @Test - fun testLoggingToEndpoint() { + fun testLoggingToEndpoint() = runBlockingTest { boundary.capture({ throw IOException("Test") }) @@ -70,11 +71,7 @@ class ErrorBoundaryTest { } @Test - fun testItDoesNotLogTheSameExceptionMultipleTimes() { - boundary.capture({ - throw IOException("Test") - }) - + fun testItDoesNotLogTheSameExceptionMultipleTimes() = runBlockingTest { boundary.capture({ throw IOException("Test") }) diff --git a/src/test/java/com/statsig/androidsdk/StatsigLongInitializationTimeoutTest.kt b/src/test/java/com/statsig/androidsdk/StatsigLongInitializationTimeoutTest.kt new file mode 100644 index 0000000..ef03bb0 --- /dev/null +++ b/src/test/java/com/statsig/androidsdk/StatsigLongInitializationTimeoutTest.kt @@ -0,0 +1,91 @@ +package com.statsig.androidsdk + +import android.app.Application +import io.mockk.every +import io.mockk.mockk +import io.mockk.spyk +import kotlinx.coroutines.* +import kotlinx.coroutines.test.runBlockingTest +import okhttp3.mockwebserver.Dispatcher +import okhttp3.mockwebserver.MockResponse +import okhttp3.mockwebserver.MockWebServer +import okhttp3.mockwebserver.RecordedRequest +import org.junit.After +import org.junit.Assert.assertTrue +import org.junit.Before +import org.junit.Test +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit + +class StatsigLongInitializationTimeoutTest { + + private var app: Application = mockk() + private lateinit var client: StatsigClient + private lateinit var errorBoundary: ErrorBoundary + private lateinit var mockWebServer: MockWebServer + private var initializeHits = 0 + + @Before + fun setup() { + mockWebServer = MockWebServer() + val dispatcher = object : Dispatcher() { + override fun dispatch(request: RecordedRequest): MockResponse { + return if (request.path!!.contains("initialize")) { + initializeHits++ + runBlocking { + delay(500) + } + MockResponse() + .setBody("{\"result\":\"error logged\"}") + .setResponseCode(503) + } else { + MockResponse().setResponseCode(404) + } + } + } + mockWebServer.dispatcher = dispatcher + mockWebServer.start() + client = spyk(StatsigClient(), recordPrivateCalls = true) + client.errorBoundary = spyk(client.errorBoundary) + errorBoundary = client.errorBoundary + + TestUtil.mockDispatchers() + TestUtil.stubAppFunctions(app) + + every { + errorBoundary.getUrl() + } returns mockWebServer.url("/v1/sdk_exception").toString() + + client.errorBoundary = errorBoundary + } + + @After + fun tearDown() { + mockWebServer.shutdown() + } + + @Test + fun testInitializeAsyncWithSlowErrorBoundary() = runBlockingTest { + var initTimeout = 10000L + val latch = CountDownLatch(1) + + client.initializeAsync( + app, + "client-key", + StatsigUser("test_user"), + object : IStatsigCallback { + override fun onStatsigInitialize(details: InitializationDetails) { + latch.countDown() + } + + override fun onStatsigUpdateUser() { + // no op + } + }, + StatsigOptions(initTimeoutMs = initTimeout, api = mockWebServer.url("/").toString()), + ) + latch.await(initTimeout, TimeUnit.SECONDS) + assert(client.isInitialized()) + assertTrue(initializeHits === 1) + } +} diff --git a/src/test/java/com/statsig/androidsdk/StatsigMultipleInitializeTest.kt b/src/test/java/com/statsig/androidsdk/StatsigMultipleInitializeTest.kt new file mode 100644 index 0000000..d1481d0 --- /dev/null +++ b/src/test/java/com/statsig/androidsdk/StatsigMultipleInitializeTest.kt @@ -0,0 +1,170 @@ +package com.statsig.androidsdk + +import android.app.Application +import io.mockk.* +import kotlinx.coroutines.* +import org.junit.Before +import org.junit.Test + +class StatsigMultipleInitializeTest { + + private lateinit var client: StatsigClient + private lateinit var app: Application + private lateinit var network: StatsigNetwork + + @Before + fun setup() { + TestUtil.mockDispatchers() + app = mockk(relaxed = true) + client = spyk(StatsigClient(), recordPrivateCalls = true) + network = TestUtil.mockNetwork() + client.statsigNetwork = network + + TestUtil.stubAppFunctions(app) + + coEvery { + network.initialize( + api = any(), + user = any(), + sinceTime = any(), + metadata = any(), + coroutineScope = any(), + context = any(), + diagnostics = any(), + hashUsed = any(), + previousDerivedFields = any(), + ) + } coAnswers { + TestUtil.makeInitializeResponse() + } + } + + @Test + fun testMultipleInitializeAsyncCalls() { + val job1 = GlobalScope.launch(Dispatchers.IO) { + client.initializeAsync(app, "client-key", StatsigUser("test_user")) + } + + val job2 = GlobalScope.launch(Dispatchers.IO) { + client.initializeAsync(app, "client-key", StatsigUser("test_user")) + } + + val job3 = GlobalScope.launch(Dispatchers.IO) { + client.initializeAsync(app, "client-key", StatsigUser("test_user")) + } + + runBlocking { + joinAll(job1, job2, job3) + } + coVerify(exactly = 1) { + network.initialize( + api = any(), + user = any(), + sinceTime = any(), + metadata = any(), + coroutineScope = any(), + context = any(), + diagnostics = any(), + hashUsed = any(), + previousDerivedFields = any(), + ) + } + } + + @Test + fun testMultipleInitializeCalls() { + val job1 = GlobalScope.launch(Dispatchers.IO) { + client.initialize(app, "client-key", StatsigUser("test_user")) + } + + val job2 = GlobalScope.launch(Dispatchers.IO) { + client.initialize(app, "client-key", StatsigUser("test_user")) + } + + val job3 = GlobalScope.launch(Dispatchers.IO) { + client.initialize(app, "client-key", StatsigUser("test_user")) + } + + runBlocking { + joinAll(job1, job2, job3) + } + coVerify(exactly = 1) { + network.initialize( + api = any(), + user = any(), + sinceTime = any(), + metadata = any(), + coroutineScope = any(), + context = any(), + diagnostics = any(), + hashUsed = any(), + previousDerivedFields = any(), + ) + } + } + + @Test + fun testMultipleInitializeCallsOnMain() { + val job1 = GlobalScope.launch(Dispatchers.Default) { + client.initialize(app, "client-key", StatsigUser("test_user")) + } + + val job2 = GlobalScope.launch(Dispatchers.Default) { + client.initialize(app, "client-key", StatsigUser("test_user")) + } + + val job3 = GlobalScope.launch(Dispatchers.Default) { + client.initialize(app, "client-key", StatsigUser("test_user")) + } + + runBlocking { + joinAll(job1, job2, job3) + } + + coVerify(exactly = 1) { + network.initialize( + api = any(), + user = any(), + sinceTime = any(), + metadata = any(), + coroutineScope = any(), + context = any(), + diagnostics = any(), + hashUsed = any(), + previousDerivedFields = any(), + ) + } + } + + @Test + fun testMultipleInitializeAsyncCallsOnMain() { + val job1 = GlobalScope.launch(Dispatchers.Default) { + client.initializeAsync(app, "client-key", StatsigUser("test_user")) + } + + val job2 = GlobalScope.launch(Dispatchers.Default) { + client.initializeAsync(app, "client-key", StatsigUser("test_user")) + } + + val job3 = GlobalScope.launch(Dispatchers.Default) { + client.initializeAsync(app, "client-key", StatsigUser("test_user")) + } + + runBlocking { + joinAll(job1, job2, job3) + } + coVerify(exactly = 1) { + network.initialize( + api = any(), + user = any(), + sinceTime = any(), + metadata = any(), + coroutineScope = any(), + context = any(), + diagnostics = any(), + hashUsed = any(), + previousDerivedFields = any(), + ) + } + } +}