From f55cf64923b369e0061f80a54328e97cf8448f14 Mon Sep 17 00:00:00 2001
From: Andy Grundman <andy@hybridized.org>
Date: Fri, 10 Jan 2025 04:35:15 -0500
Subject: [PATCH] fix(rtp) Use more accurate audio timestamp on Windows.
 Refactor audio queue to support passing capture_timestamp.

---
 src/audio.cpp                    | 20 ++++++++------
 src/audio.h                      | 26 +++++++++++++++++-
 src/platform/common.h            |  2 +-
 src/platform/linux/audio.cpp     |  5 ++--
 src/platform/macos/microphone.mm |  3 ++-
 src/platform/windows/audio.cpp   | 43 ++++++++++++++++++++++++++++--
 src/platform/windows/misc.h      | 11 ++++++++
 src/stream.cpp                   | 45 +++++++++++++++++++-------------
 tests/unit/test_audio.cpp        |  2 +-
 9 files changed, 123 insertions(+), 34 deletions(-)

diff --git a/src/audio.cpp b/src/audio.cpp
index 82b1ec37351..019eced7191 100644
--- a/src/audio.cpp
+++ b/src/audio.cpp
@@ -18,7 +18,7 @@
 namespace audio {
   using namespace std::literals;
   using opus_t = util::safe_ptr<OpusMSEncoder, opus_multistream_encoder_destroy>;
-  using sample_queue_t = std::shared_ptr<safe::queue_t<std::vector<float>>>;
+  using sample_queue_t = std::shared_ptr<safe::queue_t<audio_with_timestamp_t>>;
 
   static int
   start_audio_control(audio_ctx_t &ctx);
@@ -114,9 +114,9 @@ namespace audio {
 
     auto frame_size = config.packetDuration * stream.sampleRate / 1000;
     while (auto sample = samples->pop()) {
-      buffer_t packet { 1400 };
+      buffer_t packet_data { 1400 };
 
-      int bytes = opus_multistream_encode_float(opus.get(), sample->data(), frame_size, std::begin(packet), packet.size());
+      int bytes = opus_multistream_encode_float(opus.get(), sample->pcm.data(), frame_size, std::begin(packet_data), packet_data.size());
       if (bytes < 0) {
         BOOST_LOG(error) << "Couldn't encode audio: "sv << opus_strerror(bytes);
         packets->stop();
@@ -124,8 +124,12 @@ namespace audio {
         return;
       }
 
-      packet.fake_resize(bytes);
-      packets->raise(channel_data, std::move(packet));
+      packet_data.fake_resize(bytes);
+
+      auto packet = std::make_unique<packet_raw_t>(std::move(packet_data));
+      packet->channel_data = channel_data;
+      packet->capture_timestamp = sample->capture_timestamp;
+      packets->raise(std::move(packet));
     }
   }
 
@@ -216,10 +220,10 @@ namespace audio {
     int samples_per_frame = frame_size * stream.channelCount;
 
     while (!shutdown_event->peek()) {
-      std::vector<float> sample_buffer;
-      sample_buffer.resize(samples_per_frame);
+      audio_with_timestamp_t sample_buffer;
+      sample_buffer.pcm.resize(samples_per_frame);
 
-      auto status = mic->sample(sample_buffer);
+      auto status = mic->sample(sample_buffer.pcm, sample_buffer.capture_timestamp);
       switch (status) {
         case platf::capture_e::ok:
           break;
diff --git a/src/audio.h b/src/audio.h
index 927dfdef20b..c9e12a18946 100644
--- a/src/audio.h
+++ b/src/audio.h
@@ -68,7 +68,31 @@ namespace audio {
   };
 
   using buffer_t = util::buffer_t<std::uint8_t>;
-  using packet_t = std::pair<void *, buffer_t>;
+
+  struct packet_raw_t {  // Encoded audio packet plus the metadata the audio broadcast thread needs to send it.
+    virtual ~packet_raw_t() = default;
+
+    explicit packet_raw_t(buffer_t &&packet_data):  // Takes ownership of the encoded bytes; explicit to forbid implicit buffer -> packet conversion.
+        packet_data { std::move(packet_data) }
+    { }
+
+    size_t
+    data_size() const {  // Size of the encoded payload as reported by the buffer; does not modify the packet.
+      return packet_data.size();
+    }
+
+    buffer_t packet_data;  ///< Encoded payload (filled by the Opus encoder before the packet is queued).
+    void *channel_data = nullptr;  ///< Opaque per-channel pointer; the audio broadcast thread casts it to its session type.
+    std::chrono::steady_clock::time_point capture_timestamp;  ///< Capture time of the source samples; the RTP timestamp is derived from it.
+  };
+
+  using packet_t = std::unique_ptr<packet_raw_t>;
+
+  struct audio_with_timestamp_t {  // A buffer of captured PCM samples paired with its capture time.
+    std::vector<float> pcm;  ///< Float samples; the capture loop sizes this to frame_size * channelCount.
+    std::chrono::steady_clock::time_point capture_timestamp;  ///< Written by the mic's sample() call through its capture_timestamp_out argument.
+  };
+
   using audio_ctx_ref_t = safe::shared_t<audio_ctx_t>::ptr_t;
 
   void
diff --git a/src/platform/common.h b/src/platform/common.h
index abcbefc82d8..4bfec582e3e 100644
--- a/src/platform/common.h
+++ b/src/platform/common.h
@@ -537,7 +537,7 @@ namespace platf {
   class mic_t {
   public:
     virtual capture_e
-    sample(std::vector<float> &frame_buffer) = 0;
+    sample(std::vector<float> &frame_buffer, std::chrono::steady_clock::time_point &capture_timestamp_out) = 0;
 
     virtual ~mic_t() = default;
   };
diff --git a/src/platform/linux/audio.cpp b/src/platform/linux/audio.cpp
index a48ee2f028d..ad987fc9805 100644
--- a/src/platform/linux/audio.cpp
+++ b/src/platform/linux/audio.cpp
@@ -54,8 +54,9 @@ namespace platf {
     util::safe_ptr<pa_simple, pa_simple_free> mic;
 
     capture_e
-    sample(std::vector<float> &sample_buf) override {
+    sample(std::vector<float> &sample_buf, std::chrono::steady_clock::time_point &capture_timestamp_out) override {
       auto sample_size = sample_buf.size();
+      capture_timestamp_out = std::chrono::steady_clock::now();
 
       auto buf = sample_buf.data();
       int status;
@@ -535,4 +536,4 @@ namespace platf {
 
     return audio;
   }
-}  // namespace platf
\ No newline at end of file
+}  // namespace platf
diff --git a/src/platform/macos/microphone.mm b/src/platform/macos/microphone.mm
index 8d2129f28b3..5bfe82beedf 100644
--- a/src/platform/macos/microphone.mm
+++ b/src/platform/macos/microphone.mm
@@ -19,8 +19,9 @@
     }
 
     capture_e
-    sample(std::vector<float> &sample_in) override {
+    sample(std::vector<float> &sample_in, std::chrono::steady_clock::time_point &capture_timestamp_out) override {
       auto sample_size = sample_in.size();
+      capture_timestamp_out = std::chrono::steady_clock::now();
 
       uint32_t length = 0;
       void *byteSampleBuffer = TPCircularBufferTail(&av_audio_capture->audioSampleBuffer, &length);
diff --git a/src/platform/windows/audio.cpp b/src/platform/windows/audio.cpp
index 3c401976afc..389c13de54a 100644
--- a/src/platform/windows/audio.cpp
+++ b/src/platform/windows/audio.cpp
@@ -417,7 +417,7 @@ namespace platf::audio {
   class mic_wasapi_t: public mic_t {
   public:
     capture_e
-    sample(std::vector<float> &sample_out) override {
+    sample(std::vector<float> &sample_out, std::chrono::steady_clock::time_point &capture_timestamp_out) override {
       auto sample_size = sample_out.size();
 
       // Refill the sample buffer if needed
@@ -428,6 +428,8 @@ namespace platf::audio {
         }
       }
 
+      capture_timestamp_out = capture_timestamp;
+
       // Fill the output buffer with samples
       std::copy_n(std::begin(sample_buf), sample_size, std::begin(sample_out));
 
@@ -499,6 +501,8 @@ namespace platf::audio {
       REFERENCE_TIME default_latency;
       audio_client->GetDevicePeriod(&default_latency, nullptr);
       default_latency_ms = default_latency / 1000;
+      // XXX: the conversion above is actually wrong because REFERENCE_TIME is in 100ns units; the correct
+      // millisecond conversion is to divide by 10000. It is deliberately left unchanged here.
 
       std::uint32_t frames;
       status = audio_client->GetBufferSize(&frames);
@@ -541,6 +545,8 @@ namespace platf::audio {
         return -1;
       }
 
+      qpc_status = QPC_PENDING;
+
       return 0;
     }
 
@@ -572,6 +578,7 @@ namespace platf::audio {
       // number of samples / number of channels
       struct block_aligned_t {
         std::uint32_t audio_sample_size;
+        std::uint64_t capture_ts_100ns;
       } block_aligned;
 
       // Check if the default audio device has changed
@@ -606,7 +613,10 @@ namespace platf::audio {
           (BYTE **) &sample_aligned.samples,
           &block_aligned.audio_sample_size,
           &buffer_flags,
-          nullptr, nullptr);
+          nullptr,
+          &block_aligned.capture_ts_100ns);
+
+        auto capture_timestamp_fallback = std::chrono::steady_clock::now();
 
         switch (status) {
           case S_OK:
@@ -622,6 +632,10 @@ namespace platf::audio {
           BOOST_LOG(debug) << "Audio capture signaled buffer discontinuity";
         }
 
+        if (buffer_flags & AUDCLNT_BUFFERFLAGS_TIMESTAMP_ERROR) {
+          BOOST_LOG(warning) << "Audio capture signaled AUDCLNT_BUFFERFLAGS_TIMESTAMP_ERROR";
+        }
+
         sample_aligned.uninitialized = std::end(sample_buf) - sample_buf_pos;
         auto n = std::min(sample_aligned.uninitialized, block_aligned.audio_sample_size * channels);
 
@@ -638,6 +652,29 @@ namespace platf::audio {
 
         sample_buf_pos += n;
 
+        // When beginning capture, check that the QPC timestamps from GetBuffer() are using the
+        // same clock. If the offset is too large, we fall back to fudging the timestamps.
+        if (qpc_status == QPC_PENDING) {
+          auto qpc_capture_timestamp = std::chrono::steady_clock::time_point{std::chrono::microseconds{block_aligned.capture_ts_100ns / 10}};
+          auto qpc_offset_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
+            capture_timestamp_fallback - qpc_capture_timestamp
+          ).count();
+
+          // Expected value for qpc_offset_ms should be around -10ms, where 10 is the device's buffer size
+          if (abs(qpc_offset_ms) < MAX_QPC_TIMESTAMP_OFFSET_MS) {
+            qpc_status = QPC_VALID;
+            BOOST_LOG(info) << "Audio supports accurate timestamps. Offset (ms): " << qpc_offset_ms;
+          }
+          else {
+            qpc_status = QPC_INVALID;
+            BOOST_LOG(info) << "Audio timestamps out of range, accurate timestamps are disabled. Offset (ms): " << qpc_offset_ms;
+          }
+        }
+
+        capture_timestamp = (qpc_status == QPC_VALID)
+          ? std::chrono::steady_clock::time_point{std::chrono::microseconds{block_aligned.capture_ts_100ns / 10}}
+          : capture_timestamp_fallback; // std::chrono::steady_clock::now()
+
         audio_capture->ReleaseBuffer(block_aligned.audio_sample_size);
       }
 
@@ -668,6 +705,8 @@ namespace platf::audio {
     util::buffer_t<float> sample_buf;
     float *sample_buf_pos;
     int channels;
+    std::chrono::steady_clock::time_point capture_timestamp;
+    qpc_status_t qpc_status;
 
     HANDLE mmcss_task_handle = NULL;
   };
diff --git a/src/platform/windows/misc.h b/src/platform/windows/misc.h
index b045104f57b..f561c058abe 100644
--- a/src/platform/windows/misc.h
+++ b/src/platform/windows/misc.h
@@ -9,6 +9,17 @@
 #include <windows.h>
 #include <winnt.h>
 
+// Windows provides a timestamp from GetBuffer() indicating exactly when audio was captured.
+// Before trusting this timestamp, we check that it is compatible with our other time code.
+
+constexpr int MAX_QPC_TIMESTAMP_OFFSET_MS = 50;  ///< QPC is allowed to be +/- this many milliseconds from now().
+
+enum qpc_status_t : int {
+  QPC_PENDING,   ///< QPC offset will be checked after capturing the first audio packet
+  QPC_INVALID,   ///< QPC offset reached or exceeded MAX_QPC_TIMESTAMP_OFFSET_MS; timestamps fall back to steady_clock::now()
+  QPC_VALID      ///< QPC offset fell within acceptable range and the GetBuffer() timestamps will be used
+};
+
 namespace platf {
   void
   print_status(const std::string_view &prefix, HRESULT status);
diff --git a/src/stream.cpp b/src/stream.cpp
index 953342cd0f9..58c1a4ab0b0 100644
--- a/src/stream.cpp
+++ b/src/stream.cpp
@@ -331,6 +331,8 @@ namespace stream {
     udp::socket audio_sock { io_context };
 
     control_server_t control_server;
+
+    std::chrono::steady_clock::time_point av_timestamp_epoch;
   };
 
   struct session_t {
@@ -371,7 +373,6 @@ namespace stream {
       // avRiKeyId == util::endian::big(First (sizeof(avRiKeyId)) bytes of launch_session->iv)
       std::uint32_t avRiKeyId;
       std::uint16_t sequenceNumber;
-      std::chrono::steady_clock::time_point timestamp_epoch;
       udp::endpoint peer;
 
       util::buffer_t<char> shards;
@@ -1271,10 +1272,9 @@ namespace stream {
   }
 
   void
-  videoBroadcastThread(udp::socket &sock) {
+  videoBroadcastThread(udp::socket &sock, std::chrono::steady_clock::time_point av_timestamp_epoch) {
     auto shutdown_event = mail::man->event<bool>(mail::broadcast_shutdown);
     auto packets = mail::man->queue<video::packet_t>(mail::video_packets);
-    auto timestamp_epoch = std::chrono::steady_clock::now();
 
     // Video traffic is sent on this thread
     platf::adjust_thread_priority(platf::thread_priority_e::high);
@@ -1479,14 +1479,21 @@ namespace stream {
             auto *inspect = (video_packet_raw_t *) shards.data(x);
 
             // RTP video timestamps use a 90 KHz clock
+            static auto _last_frame_timestamp = std::chrono::steady_clock::now();
             auto timestamp = static_cast<std::uint32_t>(
               std::chrono::duration_cast<std::chrono::microseconds>(
                 packet->frame_timestamp
-                  ? *packet->frame_timestamp - timestamp_epoch
-                  : std::chrono::steady_clock::now() - timestamp_epoch // is this fallback needed?
+                  ? *packet->frame_timestamp - av_timestamp_epoch
+                  : _last_frame_timestamp - av_timestamp_epoch
               ).count() / (1000.0 / 90)
             );
 
+            if (packet->frame_timestamp) {
+              _last_frame_timestamp = *packet->frame_timestamp;
+            }
+
+            BOOST_LOG(verbose) << "Video [seq "sv << lowseq + x << ", pts "sv << timestamp << "] ::  send..."sv;
+
             inspect->packet.fecInfo =
               (x << 12 |
                 shards.data_shards << 22 |
@@ -1602,7 +1609,7 @@ namespace stream {
   }
 
   void
-  audioBroadcastThread(udp::socket &sock) {
+  audioBroadcastThread(udp::socket &sock, std::chrono::steady_clock::time_point av_timestamp_epoch) {
     auto shutdown_event = mail::man->event<bool>(mail::broadcast_shutdown);
     auto packets = mail::man->queue<audio::packet_t>(mail::audio_packets);
 
@@ -1630,28 +1637,29 @@ namespace stream {
         break;
       }
 
-      TUPLE_2D_REF(channel_data, packet_data, *packet);
-      auto session = (session_t *) channel_data;
-
+      auto session = (session_t *) packet->channel_data;
       auto sequenceNumber = session->audio.sequenceNumber;
-      // Audio timestamps are in milliseconds and should be AudioPacketDuration (5ms or 10ms) apart
+
+      // Audio timestamps are in milliseconds
       auto timestamp = static_cast<std::uint32_t>(
-        std::chrono::duration_cast<std::chrono::microseconds>(
-          std::chrono::steady_clock::now() - session->audio.timestamp_epoch
-        ).count() / 1000.0
+        std::chrono::duration_cast<std::chrono::milliseconds>(
+          packet->capture_timestamp - av_timestamp_epoch
+        ).count()
       );
 
       *(std::uint32_t *) iv.data() = util::endian::big<std::uint32_t>(session->audio.avRiKeyId + sequenceNumber);
 
       auto &shards_p = session->audio.shards_p;
 
-      auto bytes = encode_audio(session->config.encryptionFlagsEnabled & SS_ENC_AUDIO, packet_data,
+      auto bytes = encode_audio(session->config.encryptionFlagsEnabled & SS_ENC_AUDIO, packet->packet_data,
         shards_p[sequenceNumber % RTPA_DATA_SHARDS], iv, session->audio.cipher);
       if (bytes < 0) {
         BOOST_LOG(error) << "Couldn't encode audio packet"sv;
         break;
       }
 
+      BOOST_LOG(verbose) << "Audio [seq "sv << sequenceNumber << ", pts "sv << timestamp << "] ::  send..."sv;
+
       audio_packet.rtp.sequenceNumber = util::endian::big(sequenceNumber);
       audio_packet.rtp.timestamp = util::endian::big(timestamp);
 
@@ -1670,7 +1678,6 @@ namespace stream {
           session->localAddress,
         };
         platf::send(send_info);
-        BOOST_LOG(verbose) << "Audio ["sv << sequenceNumber << "] ::  send..."sv;
 
         auto &fec_packet = session->audio.fec_packet;
         // initialize the FEC header at the beginning of the FEC block
@@ -1762,10 +1769,13 @@ namespace stream {
       return -1;
     }
 
+    // The zero point for both audio & video RTP timestamps
+    ctx.av_timestamp_epoch = std::chrono::steady_clock::now();
+
     ctx.message_queue_queue = std::make_shared<message_queue_queue_t::element_type>(30);
 
-    ctx.video_thread = std::thread { videoBroadcastThread, std::ref(ctx.video_sock) };
-    ctx.audio_thread = std::thread { audioBroadcastThread, std::ref(ctx.audio_sock) };
+    ctx.video_thread = std::thread { videoBroadcastThread, std::ref(ctx.video_sock), ctx.av_timestamp_epoch };
+    ctx.audio_thread = std::thread { audioBroadcastThread, std::ref(ctx.audio_sock), ctx.av_timestamp_epoch };
     ctx.control_thread = std::thread { controlBroadcastThread, &ctx.control_server };
 
     ctx.recv_thread = std::thread { recvThread, std::ref(ctx) };
@@ -2076,7 +2086,6 @@ namespace stream {
       session->audio.ping_payload = launch_session.av_ping_payload;
       session->audio.avRiKeyId = util::endian::big(*(std::uint32_t *) launch_session.iv.data());
       session->audio.sequenceNumber = 0;
-      session->audio.timestamp_epoch = std::chrono::steady_clock::now();
 
       session->control.peer = nullptr;
       session->state.store(state_e::STOPPED, std::memory_order_relaxed);
diff --git a/tests/unit/test_audio.cpp b/tests/unit/test_audio.cpp
index 93ae0d80b14..4aa0e7f28f5 100644
--- a/tests/unit/test_audio.cpp
+++ b/tests/unit/test_audio.cpp
@@ -54,7 +54,7 @@ TEST_P(AudioTest, TestEncode) {
       if (shutdown_event->peek()) {
         break;
       }
-      auto packet_data = packet->second;
+      auto packet_data = packet->packet_data;
       if (packet_data.size() == 0) {
         FAIL() << "Empty packet data";
       }