Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Synchronize token refresh #1020

Merged
merged 2 commits into from
Nov 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadFactory;
Expand Down Expand Up @@ -95,8 +94,7 @@ public class MqttPublisher implements MessagePublisher {
private static final int INITIALIZE_TIME_MS = 20000;
private static final String BROKER_URL_FORMAT = "%s://%s:%s";
private static final int PUBLISH_THREAD_COUNT = 10;
private static final int TOKEN_EXPIRATION_SEC = 60 * 60;
private static final int TOKEN_EXPIRATION_MS = TOKEN_EXPIRATION_SEC * 1000;
private static final Duration TOKEN_EXPIRATION = Duration.ofHours(1);
private static final String TICKLE_TOPIC = "events/udmi";
private static final long TICKLE_PERIOD_SEC = 10;
private static final String REFLECTOR_PUBLIC_KEY = "reflector/rsa_public.pem";
Expand Down Expand Up @@ -131,6 +129,7 @@ public class MqttPublisher implements MessagePublisher {
private final Envelope savedState = new Envelope();
private final AtomicInteger publisherQueueSize = new AtomicInteger();
private final AtomicInteger publishCount = new AtomicInteger();
private final String mqttClientId;
private long mqttTokenSetTimeMs;
private MqttConnectOptions mqttConnectOptions;
private boolean shutdown;
Expand All @@ -150,9 +149,10 @@ public class MqttPublisher implements MessagePublisher {
providerHostname = getProviderHostname(config);
topicBase = getTopicBase();
clientId = catchToNull(() -> config.reflector_endpoint.client_id);
LOG.info(deviceId + " token expiration sec " + TOKEN_EXPIRATION_SEC);
LOG.info(deviceId + " token expiration sec " + TOKEN_EXPIRATION.getSeconds());
certManager = getCertManager();
mqttClient = newMqttClient(deviceId);
mqttClientId = mqttClient.getClientId();
connectMqttClient(deviceId);
tickler = scheduleTickler();
}
Expand Down Expand Up @@ -251,7 +251,7 @@ private ScheduledFuture<?> scheduleTickler() {
}

private void tickleConnection() {
LOG.debug("Tickle " + mqttClient.getClientId());
LOG.debug("Tickle " + mqttClientId);
if (shutdown) {
try {
LOG.info("Tickler closing connection due to shutdown request");
Expand Down Expand Up @@ -308,7 +308,8 @@ private synchronized void publishCore(String deviceId, String topic, String payl
publishRaw(deviceId, topic, payload, start);
}

private void publishRaw(String deviceId, String topic, String payload, Instant start) {
private synchronized void publishRaw(String deviceId, String topic, String payload,
Instant start) {
try {
publisherQueueSize.decrementAndGet();
if (!connectWait.tryAcquire(INITIALIZE_TIME_MS, TimeUnit.MILLISECONDS)) {
Expand Down Expand Up @@ -362,14 +363,14 @@ private synchronized void delayStateUpdate(String deviceId) {
lastStateTime.put(deviceId, now);
}

private void sendMessage(String mqttTopic, byte[] mqttMessage) throws Exception {
private synchronized void sendMessage(String mqttTopic, byte[] mqttMessage) throws Exception {
LOG.debug(deviceId + " sending message to " + mqttTopic);
mqttClient.publish(mqttTopic, mqttMessage, QOS_AT_LEAST_ONCE, MQTT_NO_RETAIN);
publishCounter.incrementAndGet();
}

@Override
public void close() {
public synchronized void close() {
try {
LOG.debug(format("Shutting down executor %x", publisherExecutor.hashCode()));
ifNotNullThen(tickler, () -> tickler.cancel(false));
Expand All @@ -388,15 +389,15 @@ public void close() {

@Override
public String getSubscriptionId() {
return mqttClient.getClientId();
return mqttClientId;
}

@Override
public void activate() {
}

@Override
public boolean isActive() {
public synchronized boolean isActive() {
return mqttClient.isConnected();
}

Expand Down Expand Up @@ -430,7 +431,7 @@ private MqttClient newMqttClient(String deviceId) {
}
}

private void connectMqttClient(String deviceId) {
private synchronized void connectMqttClient(String deviceId) {
try {
if (mqttClient.isConnected()) {
return;
Expand Down Expand Up @@ -465,7 +466,7 @@ private String getUserName() {
};
}

private void connectAndSetupMqtt() {
private synchronized void connectAndSetupMqtt() {
try {
LOG.info(deviceId + " creating new auth token for audience " + projectId);
mqttConnectOptions.setPassword(getAuthToken(projectId));
Expand Down Expand Up @@ -497,8 +498,8 @@ private char[] getHashPassword(String audience) {
return hashKeyPassword.toCharArray();
}

private void maybeRefreshJwt() {
long refreshTime = mqttTokenSetTimeMs + TOKEN_EXPIRATION_MS / 2;
private synchronized void maybeRefreshJwt() {
long refreshTime = mqttTokenSetTimeMs + TOKEN_EXPIRATION.toMillis() / 2;
long currentTimeMillis = System.currentTimeMillis();
long remaining = refreshTime - currentTimeMillis;
LOG.debug(deviceId + " remaining until refresh " + remaining);
Expand Down Expand Up @@ -552,7 +553,7 @@ private void subscribeToConfig(String deviceId) {
clientSubscribe(CONFIG_TOPIC, QOS_AT_LEAST_ONCE);
}

private void clientSubscribe(String topicSuffix, int qos) {
private synchronized void clientSubscribe(String topicSuffix, int qos) {
String topic = topicBase + topicSuffix;
try {
LOG.info(format("Subscribing with qos %d to topic %s", qos, topic));
Expand Down Expand Up @@ -605,7 +606,7 @@ private String createJwt(String projectId, byte[] privateKeyBytes, String algori
JwtBuilder jwtBuilder =
Jwts.builder()
.setIssuedAt(now.toDate())
.setExpiration(now.plusMillis(TOKEN_EXPIRATION_MS).toDate())
.setExpiration(now.plusMillis((int) TOKEN_EXPIRATION.toMillis()).toDate())
.setAudience(projectId);

LOG.info(format("Creating jwt %s key with audience %s", algorithm, projectId));
Expand Down
Loading