diff --git a/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DefaultDnsClient.java b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DefaultDnsClient.java index 06a5de2f73..b909dc4bf0 100644 --- a/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DefaultDnsClient.java +++ b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DefaultDnsClient.java @@ -100,6 +100,7 @@ import static io.servicetalk.dns.discovery.netty.DnsResolverAddressTypes.IPV4_PREFERRED; import static io.servicetalk.dns.discovery.netty.DnsResolverAddressTypes.IPV6_PREFERRED; import static io.servicetalk.dns.discovery.netty.DnsResolverAddressTypes.preferredAddressType; +import static io.servicetalk.dns.discovery.netty.DnsResolverAddressTypes.toRecordTypeNames; import static io.servicetalk.dns.discovery.netty.ServiceDiscovererUtils.calculateDifference; import static io.servicetalk.transport.netty.internal.BuilderUtils.datagramChannel; import static io.servicetalk.transport.netty.internal.BuilderUtils.socketChannel; @@ -113,6 +114,7 @@ import static java.util.Collections.singletonList; import static java.util.Comparator.comparing; import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.NANOSECONDS; import static java.util.concurrent.TimeUnit.SECONDS; import static java.util.function.Function.identity; @@ -129,6 +131,7 @@ final class DefaultDnsClient implements DnsClient { private final MinTtlCache ttlCache; private final long maxTTLNanos; private final long ttlJitterNanos; + private final long resolutionTimeoutMillis; private final ListenableAsyncCloseable asyncCloseable; @Nullable private final DnsServiceDiscovererObserver observer; @@ -149,6 +152,7 @@ final class DefaultDnsClient implements DnsClient { Duration srvHostNameRepeatInitialDelay, Duration srvHostNameRepeatJitter, @Nullable Integer maxUdpPayloadSize, @Nullable final Integer ndots, @Nullable final Boolean optResourceEnabled, @Nullable final Duration queryTimeout, + @Nullable Duration resolutionTimeout, final DnsResolverAddressTypes dnsResolverAddressTypes, @Nullable final SocketAddress localAddress, @Nullable final DnsServerAddressStreamProvider dnsServerAddressStreamProvider, @@ -222,6 +226,10 @@ final class DefaultDnsClient implements DnsClient { builder.nameServerProvider(toNettyType(dnsServerAddressStreamProvider)); } resolver = builder.build(); + this.resolutionTimeoutMillis = resolutionTimeout != null ? resolutionTimeout.toMillis() : + // Default value is chosen based on a combination of default "timeout" and "attempts" options of + // /etc/resolv.conf: https://man7.org/linux/man-pages/man5/resolv.conf.5.html + resolver.queryTimeoutMillis() * 2; } @Override @@ -424,9 +432,21 @@ protected AbstractDnsSubscription newSubscription( return new AbstractDnsSubscription(subscriber) { @Override protected Future> doDnsQuery(final boolean scheduledQuery) { - Promise> promise = nettyIoExecutor.eventLoopGroup().next().newPromise(); - resolver.resolveAll(new DefaultDnsQuestion(name, SRV)) - .addListener((Future> completedFuture) -> { + final EventLoop eventLoop = nettyIoExecutor.eventLoopGroup().next(); + final Promise> promise = eventLoop.newPromise(); + final Future> resolveFuture = + resolver.resolveAll(new DefaultDnsQuestion(name, SRV)); + final Future timeoutFuture = resolutionTimeoutMillis == 0L ? null : eventLoop.schedule(() -> { + if (!promise.isDone() && promise.tryFailure(DnsNameResolverTimeoutException.newInstance( + name, resolutionTimeoutMillis, SRV.toString(), + SrvRecordPublisher.class, "doDnsQuery"))) { + resolveFuture.cancel(true); + } + }, resolutionTimeoutMillis, MILLISECONDS); + resolveFuture.addListener((Future> completedFuture) -> { + if (timeoutFuture != null) { + timeoutFuture.cancel(true); + } Throwable cause = completedFuture.cause(); if (cause != null) { promise.tryFailure(cause); @@ -501,9 +521,21 @@ protected Future> doDnsQuery(final boolean scheduledQuery if (scheduledQuery) { ttlCache.prepareForResolution(name); } - Promise> dnsAnswerPromise = - nettyIoExecutor.eventLoopGroup().next().newPromise(); - resolver.resolveAll(name).addListener(completedFuture -> { + final EventLoop eventLoop = nettyIoExecutor.eventLoopGroup().next(); + final Promise> dnsAnswerPromise = eventLoop.newPromise(); + final Future> resolveFuture = resolver.resolveAll(name); + final Future timeoutFuture = resolutionTimeoutMillis == 0L ? null : eventLoop.schedule(() -> { + if (!dnsAnswerPromise.isDone() && dnsAnswerPromise.tryFailure( + DnsNameResolverTimeoutException.newInstance(name, resolutionTimeoutMillis, + toRecordTypeNames(addressTypes), ARecordPublisher.class, "doDnsQuery"))) { + resolveFuture.cancel(true); + } + }, resolutionTimeoutMillis, MILLISECONDS); + + resolveFuture.addListener(completedFuture -> { + if (timeoutFuture != null) { + timeoutFuture.cancel(true); + } Throwable cause = completedFuture.cause(); if (cause != null) { dnsAnswerPromise.tryFailure(cause); diff --git a/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DefaultDnsServiceDiscovererBuilder.java b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DefaultDnsServiceDiscovererBuilder.java index 6422fe225f..816a3050f1 100644 --- a/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DefaultDnsServiceDiscovererBuilder.java +++ b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DefaultDnsServiceDiscovererBuilder.java @@ -19,6 +19,7 @@ import io.servicetalk.client.api.ServiceDiscovererEvent; import io.servicetalk.transport.api.HostAndPort; import io.servicetalk.transport.api.IoExecutor; +import io.servicetalk.utils.internal.DurationUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -124,6 +125,8 @@ public final class DefaultDnsServiceDiscovererBuilder implements DnsServiceDisco private IoExecutor ioExecutor; @Nullable private Duration queryTimeout; + @Nullable + private Duration resolutionTimeout; private int consolidateCacheSize = DEFAULT_CONSOLIDATE_CACHE_SIZE; private int minTTLSeconds = DEFAULT_MIN_TTL_POLL_SECONDS; private int maxTTLSeconds = DEFAULT_MAX_TTL_POLL_SECONDS; @@ -258,8 +261,15 @@ public DefaultDnsServiceDiscovererBuilder ndots(final int ndots) { } @Override - public DefaultDnsServiceDiscovererBuilder queryTimeout(final Duration queryTimeout) { - this.queryTimeout = queryTimeout; + public DefaultDnsServiceDiscovererBuilder queryTimeout(final @Nullable Duration queryTimeout) { + this.queryTimeout = queryTimeout == null ? null : DurationUtils.ensureNonNegative(queryTimeout, "queryTimeout"); + return this; + } + + @Override + public DefaultDnsServiceDiscovererBuilder resolutionTimeout(final @Nullable Duration resolutionTimeout) { + this.resolutionTimeout = resolutionTimeout == null ? null : + DurationUtils.ensureNonNegative(resolutionTimeout, "resolutionTimeout"); return this; } @@ -267,7 +277,7 @@ public DefaultDnsServiceDiscovererBuilder queryTimeout(final Duration queryTimeo public DefaultDnsServiceDiscovererBuilder dnsResolverAddressTypes( @Nullable final DnsResolverAddressTypes dnsResolverAddressTypes) { this.dnsResolverAddressTypes = dnsResolverAddressTypes != null ? dnsResolverAddressTypes : - systemDefault(); + DEFAULT_DNS_RESOLVER_ADDRESS_TYPES; return this; } @@ -385,8 +395,8 @@ DnsClient build() { ttlJitter.toNanos(), srvConcurrency, completeOncePreferredResolved, srvFilterDuplicateEvents, srvHostNameRepeatInitialDelay, srvHostNameRepeatJitter, maxUdpPayloadSize, ndots, optResourceEnabled, - queryTimeout, dnsResolverAddressTypes, localAddress, dnsServerAddressStreamProvider, observer, - missingRecordStatus, nxInvalidation); + queryTimeout, resolutionTimeout, dnsResolverAddressTypes, localAddress, dnsServerAddressStreamProvider, + observer, missingRecordStatus, nxInvalidation); return filterFactory == null ? rawClient : filterFactory.create(rawClient); } diff --git a/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DelegatingDnsServiceDiscovererBuilder.java b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DelegatingDnsServiceDiscovererBuilder.java index 6738f66514..41deea7049 100644 --- a/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DelegatingDnsServiceDiscovererBuilder.java +++ b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DelegatingDnsServiceDiscovererBuilder.java @@ -117,11 +117,17 @@ public DnsServiceDiscovererBuilder ndots(final int ndots) { } @Override - public DnsServiceDiscovererBuilder queryTimeout(final Duration queryTimeout) { + public DnsServiceDiscovererBuilder queryTimeout(final @Nullable Duration queryTimeout) { delegate = delegate.queryTimeout(queryTimeout); return this; } + @Override + public DnsServiceDiscovererBuilder resolutionTimeout(final @Nullable Duration resolutionTimeout) { + delegate = delegate.resolutionTimeout(resolutionTimeout); + return this; + } + @Override public DnsServiceDiscovererBuilder dnsResolverAddressTypes( @Nullable final DnsResolverAddressTypes dnsResolverAddressTypes) { diff --git a/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DnsNameResolverTimeoutException.java b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DnsNameResolverTimeoutException.java new file mode 100644 index 0000000000..56ebfad3a0 --- /dev/null +++ b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DnsNameResolverTimeoutException.java @@ -0,0 +1,39 @@ +/* + * Copyright © 2024 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.dns.discovery.netty; + +import io.servicetalk.concurrent.internal.ThrowableUtils; + +import java.net.UnknownHostException; + +final class DnsNameResolverTimeoutException extends UnknownHostException { + private static final long serialVersionUID = 3089160074512305891L; + + private DnsNameResolverTimeoutException(final String name, final String recordType, final long timeoutMs) { + super("Resolution for '" + name + "' [" + recordType + "] timed out after " + timeoutMs + " milliseconds"); + } + + @Override + public Throwable fillInStackTrace() { + return this; + } + + static DnsNameResolverTimeoutException newInstance(final String name, final long timeoutMs, final String recordType, + final Class clazz, final String method) { + return ThrowableUtils.unknownStackTrace(new DnsNameResolverTimeoutException(name, recordType, timeoutMs), + clazz, method); + } +} diff --git a/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DnsResolverAddressTypes.java b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DnsResolverAddressTypes.java index 51bc8db36d..4764108dd3 100644 --- a/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DnsResolverAddressTypes.java +++ b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DnsResolverAddressTypes.java @@ -16,6 +16,7 @@ package io.servicetalk.dns.discovery.netty; import io.netty.channel.socket.InternetProtocolFamily; +import io.netty.handler.codec.dns.DnsRecordType; import io.netty.resolver.ResolvedAddressTypes; import io.netty.resolver.dns.DnsNameResolverBuilder; @@ -52,6 +53,9 @@ public enum DnsResolverAddressTypes { */ IPV6_PREFERRED_RETURN_ALL; + private static final String A_AAAA_STRING = DnsRecordType.A + ", " + DnsRecordType.AAAA; + private static final String AAAA_A_STRING = DnsRecordType.AAAA + ", " + DnsRecordType.A; + /** * The default value, based on "java.net" system properties: {@code java.net.preferIPv4Stack} and * {@code java.net.preferIPv6Stack}. @@ -109,4 +113,22 @@ static InternetProtocolFamily preferredAddressType(ResolvedAddressTypes resolved ": " + resolvedAddressTypes); } } + + static String toRecordTypeNames(DnsResolverAddressTypes dnsResolverAddressType) { + switch (dnsResolverAddressType) { + case IPV4_ONLY: + return DnsRecordType.A.toString(); + case IPV6_ONLY: + return DnsRecordType.AAAA.toString(); + case IPV4_PREFERRED: + case IPV4_PREFERRED_RETURN_ALL: + return A_AAAA_STRING; + case IPV6_PREFERRED: + case IPV6_PREFERRED_RETURN_ALL: + return AAAA_A_STRING; + default: + throw new IllegalArgumentException("Unknown value for " + DnsResolverAddressTypes.class.getName() + + ": " + dnsResolverAddressType); + } + } } diff --git a/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DnsServiceDiscovererBuilder.java b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DnsServiceDiscovererBuilder.java index 5e505293ca..74f80111a7 100644 --- a/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DnsServiceDiscovererBuilder.java +++ b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DnsServiceDiscovererBuilder.java @@ -41,8 +41,8 @@ public interface DnsServiceDiscovererBuilder { * @return {@code this}. */ default DnsServiceDiscovererBuilder consolidateCacheSize(int consolidateCacheSize) { - throw new UnsupportedOperationException("DnsServiceDiscovererBuilder#consolidateCacheSize(int) is not " + - "supported by " + getClass()); + throw new UnsupportedOperationException( + "DnsServiceDiscovererBuilder#consolidateCacheSize(int) is not supported by " + getClass()); } /** @@ -123,8 +123,8 @@ default DnsServiceDiscovererBuilder consolidateCacheSize(int consolidateCacheSiz */ default DnsServiceDiscovererBuilder ttl(int minSeconds, int maxSeconds, int minCacheSeconds, int maxCacheSeconds, int negativeCacheSeconds) { - throw new UnsupportedOperationException("DnsServiceDiscovererBuilder#ttl(int, int, int, int, int) is not " + - "supported by " + getClass()); + throw new UnsupportedOperationException( + "DnsServiceDiscovererBuilder#ttl(int, int, int, int, int) is not supported by " + getClass()); } /** @@ -146,8 +146,8 @@ default DnsServiceDiscovererBuilder ttl(int minSeconds, int maxSeconds, int minC * @return {@code this}. */ default DnsServiceDiscovererBuilder localAddress(@Nullable SocketAddress localAddress) { - throw new UnsupportedOperationException("DnsServiceDiscovererBuilder#localAddress(SocketAddress) is not " + - "supported by " + getClass()); + throw new UnsupportedOperationException( + "DnsServiceDiscovererBuilder#localAddress(SocketAddress) is not supported by " + getClass()); } /** @@ -182,6 +182,8 @@ DnsServiceDiscovererBuilder dnsServerAddressStreamProvider( /** * Set the number of dots which must appear in a name before an initial absolute query is made. + *

+ * If not set, the default value is read from {@code ndots} option of {@code /etc/resolv.conf}). * * @param ndots the ndots value. * @return {@code this}. @@ -189,12 +191,37 @@ DnsServiceDiscovererBuilder dnsServerAddressStreamProvider( DnsServiceDiscovererBuilder ndots(int ndots); /** - * Sets the timeout of each DNS query performed by this service discoverer. + * Sets the timeout of each DNS query performed by this service discoverer as part of a resolution request. + *

+ * Zero ({@code 0}) disables the timeout. If not set, the default value is read from {@code timeout} option of + * {@code /etc/resolv.conf}). Similar to linux systems, this value may be silently capped. * * @param queryTimeout the query timeout value - * @return {@code this}. + * @return {@code this} + * @see #resolutionTimeout(Duration) + */ + DnsServiceDiscovererBuilder queryTimeout(@Nullable Duration queryTimeout); + + /** + * Sets the total timeout of each DNS resolution performed by this service discoverer. + *

+ * Each resolution may execute one or more DNS queries, like following multiple CNAME(s) or trying different search + * domains. This is the total timeout for all intermediate queries involved in a single resolution request. Note, + * that SRV resolutions may generate independent resolutions for + * {@code A/AAAA} records. In this case, this timeout will be applied to an {@code SRV} resolution and each + * {@code A/AAAA} resolution independently. + *

+ * Zero ({@code 0}) disables the timeout. If not set, it defaults to {@link #queryTimeout(Duration) query timeout} + * value multiplied by {@code 2}. + * + * @param resolutionTimeout the query timeout value + * @return {@code this} + * @see #queryTimeout(Duration) */ - DnsServiceDiscovererBuilder queryTimeout(Duration queryTimeout); + default DnsServiceDiscovererBuilder resolutionTimeout(@Nullable Duration resolutionTimeout) { + throw new UnsupportedOperationException( + "DnsServiceDiscovererBuilder#resolutionTimeout(Duration) is not supported by " + getClass()); + } /** * Sets the list of the protocol families of the address resolved. diff --git a/servicetalk-dns-discovery-netty/src/test/java/io/servicetalk/dns/discovery/netty/DefaultDnsClientTest.java b/servicetalk-dns-discovery-netty/src/test/java/io/servicetalk/dns/discovery/netty/DefaultDnsClientTest.java index 8b50d95d28..b6ef129267 100644 --- a/servicetalk-dns-discovery-netty/src/test/java/io/servicetalk/dns/discovery/netty/DefaultDnsClientTest.java +++ b/servicetalk-dns-discovery-netty/src/test/java/io/servicetalk/dns/discovery/netty/DefaultDnsClientTest.java @@ -29,6 +29,7 @@ import io.servicetalk.concurrent.api.TestExecutor; import io.servicetalk.concurrent.test.internal.TestPublisherSubscriber; import io.servicetalk.transport.netty.internal.EventLoopAwareNettyIoExecutor; +import io.servicetalk.utils.internal.DurationUtils; import io.netty.channel.EventLoopGroup; import org.apache.directory.server.dns.messages.RecordType; @@ -67,6 +68,7 @@ import static io.servicetalk.client.api.ServiceDiscovererEvent.Status.EXPIRED; import static io.servicetalk.concurrent.api.SourceAdapters.toSource; import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION; +import static io.servicetalk.concurrent.internal.TestTimeoutConstants.CI; import static io.servicetalk.dns.discovery.netty.DnsResolverAddressTypes.IPV4_ONLY; import static io.servicetalk.dns.discovery.netty.DnsResolverAddressTypes.IPV4_PREFERRED; import static io.servicetalk.dns.discovery.netty.DnsResolverAddressTypes.IPV4_PREFERRED_RETURN_ALL; @@ -85,7 +87,10 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static java.util.function.Function.identity; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasItems; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; @@ -99,6 +104,7 @@ class DefaultDnsClientTest { private static final Logger LOGGER = LoggerFactory.getLogger(DefaultDnsClientTest.class); private static final int DEFAULT_TTL = 1; + private static final Duration DEFAULT_TIMEOUT = ofMillis(CI ? 500 : 100); @RegisterExtension static final ExecutorExtension timerExecutor = ExecutorExtension.withTestExecutor() @@ -1208,6 +1214,74 @@ void cacheForARecord(boolean resubscribe) throws Exception { assertThat(subscriber.pollOnNext(50, MILLISECONDS), is(nullValue())); } + @ParameterizedTest(name = "{displayName} [{index}] recordType={0}") + @EnumSource(value = RecordType.class, names = {"A", "AAAA", "SRV"}) + void testQueryTimeout(RecordType recordType) throws Exception { + testTimeout(DEFAULT_TIMEOUT, Duration.ZERO, recordType); + } + + @ParameterizedTest(name = "{displayName} [{index}] recordType={0}") + @EnumSource(value = RecordType.class, names = {"A", "AAAA", "SRV"}) + void testResolutionTimeout(RecordType recordType) throws Exception { + testTimeout(Duration.ZERO, DEFAULT_TIMEOUT, recordType); + } + + void testTimeout(Duration queryTimeout, Duration resolutionTimeout, RecordType recordType) throws Exception { + setup(builder -> builder + .queryTimeout(queryTimeout) + .resolutionTimeout(resolutionTimeout) + .dnsResolverAddressTypes(recordType == RecordType.AAAA ? IPV6_PREFERRED : IPV4_PREFERRED)); + String srvDomain = "srv.apple.com"; + String aDomain = "a.apple.com"; + String aaaaDomain = "aaaa.apple.com"; + String ipv4 = nextIp(); + String ipv6 = nextIp6(); + recordStore.addSrv(srvDomain, aDomain, 80, DEFAULT_TTL); + recordStore.addSrv(srvDomain, aaaaDomain, 80, DEFAULT_TTL); + recordStore.addIPv4Address(aDomain, DEFAULT_TTL, ipv4); + recordStore.addIPv6Address(aaaaDomain, DEFAULT_TTL, ipv6); + + String domain; + try { + TestPublisherSubscriber subscriber; + switch (recordType) { + case A: + domain = aDomain; + recordStore.addTimeout(aDomain, RecordType.A); + subscriber = dnsQuery(aDomain); + break; + case AAAA: + domain = aaaaDomain; + recordStore.addTimeout(aaaaDomain, RecordType.AAAA); + subscriber = dnsQuery(aaaaDomain); + break; + case SRV: + domain = srvDomain; + recordStore.addTimeout(srvDomain, RecordType.SRV); + subscriber = dnsSrvQuery(srvDomain); + break; + default: + throw new IllegalArgumentException("Unknown RecordType: " + recordType); + } + Subscription subscription = subscriber.awaitSubscription(); + long startTime = System.nanoTime(); + subscription.request(1); + Throwable error = subscriber.awaitOnError(); + assertThat(error, instanceOf(UnknownHostException.class)); + assertThat(error.getMessage(), allOf(containsString(domain), containsString(recordType.name()))); + if (DurationUtils.isPositive(queryTimeout)) { + assertThat(error.getCause(), instanceOf(io.netty.resolver.dns.DnsNameResolverTimeoutException.class)); + assertThat(error.getCause().getMessage(), + allOf(containsString(domain), containsString(Long.toString(DEFAULT_TIMEOUT.toMillis())))); + } + assertThat(System.nanoTime() - startTime, greaterThanOrEqualTo(DEFAULT_TIMEOUT.toNanos())); + } finally { + recordStore.removeTimeout(aDomain, RecordType.A); + recordStore.removeTimeout(aaaaDomain, RecordType.AAAA); + recordStore.removeTimeout(srvDomain, RecordType.SRV); + } + } + private static Subscriber> mockThrowSubscriber( CountDownLatch latchOnError, Queue> queue) { @SuppressWarnings("unchecked") @@ -1233,7 +1307,13 @@ private static Subscriber> mockThrowSubscriber( } private TestPublisherSubscriber> dnsSrvQuery(String domain) { + return dnsSrvQuery(domain, (i, t) -> Completable.failed(t)); + } + + private TestPublisherSubscriber> dnsSrvQuery(String domain, + BiIntFunction retryStrategy) { Publisher> publisher = client.dnsSrvQuery(domain) + .retryWhen(retryStrategy) .flatMapConcatIterable(identity()); TestPublisherSubscriber> subscriber = new TestPublisherSubscriber<>(); @@ -1242,16 +1322,10 @@ private TestPublisherSubscriber> dnsSr } private TestPublisherSubscriber> dnsSrvQueryWithInfRetry(String domain) { - Publisher> publisher = client.dnsSrvQuery(domain) - .retry((__, err) -> { - LOGGER.error("Retrying error ", err); - return true; - }) - .flatMapConcatIterable(identity()); - TestPublisherSubscriber> subscriber = - new TestPublisherSubscriber<>(); - toSource(publisher).subscribe(subscriber); - return subscriber; + return dnsSrvQuery(domain, (__, err) -> { + LOGGER.error("Retrying error ", err); + return Completable.completed(); + }); } private TestPublisherSubscriber> dnsQuery(String domain) { diff --git a/servicetalk-dns-discovery-netty/src/test/java/io/servicetalk/dns/discovery/netty/TestRecordStore.java b/servicetalk-dns-discovery-netty/src/test/java/io/servicetalk/dns/discovery/netty/TestRecordStore.java index 0cdff5c0c6..50b797ee15 100644 --- a/servicetalk-dns-discovery-netty/src/test/java/io/servicetalk/dns/discovery/netty/TestRecordStore.java +++ b/servicetalk-dns-discovery-netty/src/test/java/io/servicetalk/dns/discovery/netty/TestRecordStore.java @@ -34,6 +34,7 @@ import java.util.Objects; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; import javax.annotation.Nullable; import static java.util.Collections.emptySet; @@ -47,6 +48,9 @@ final class TestRecordStore implements RecordStore { private static final Logger LOGGER = LoggerFactory.getLogger(TestRecordStore.class); private static final int SRV_DEFAULT_WEIGHT = 10; private static final int SRV_DEFAULT_PRIORITY = 10; + + private final Set failSet = new HashSet<>(); + private final Map timeouts = new ConcurrentHashMap<>(); private final Map>> recordsToReturnByDomain = new ConcurrentHashMap<>(); @@ -89,8 +93,6 @@ public int hashCode() { } } - private final Set failSet = new HashSet<>(); - public synchronized void addFail(final ServFail fail) { failSet.add(fail); } @@ -99,6 +101,17 @@ public synchronized void removeFail(final ServFail fail) { failSet.remove(fail); } + public void addTimeout(final String domain, final RecordType recordType) { + timeouts.put(new QuestionRecord(domain, recordType, RecordClass.IN), new CountDownLatch(1)); + } + + public void removeTimeout(final String domain, final RecordType recordType) { + CountDownLatch latch = timeouts.remove(new QuestionRecord(domain, recordType, RecordClass.IN)); + if (latch != null) { + latch.countDown(); + } + } + public synchronized void addSrv(final String domain, String targetDomain, final int port, final int ttl) { addSrv(domain, targetDomain, port, ttl, SRV_DEFAULT_WEIGHT, SRV_DEFAULT_PRIORITY); } @@ -218,9 +231,19 @@ private boolean removeRecords(ResourceRecord rr, List recordList return removed; } - @Nullable @Override public synchronized Set getRecords(final QuestionRecord questionRecord) throws DnsException { + final CountDownLatch timeoutLatch = timeouts.get(questionRecord); + if (timeoutLatch != null && timeoutLatch.getCount() > 0) { + LOGGER.debug("Holding a thread to generate a timeout for {}", questionRecord); + try { + timeoutLatch.await(); + } catch (InterruptedException e) { + DnsException dnsException = new DnsException(SERVER_FAILURE); + dnsException.initCause(e); + throw dnsException; + } + } final String domain = questionRecord.getDomainName(); if (failSet.contains(ServFail.of(questionRecord))) { throw new DnsException(SERVER_FAILURE);