Skip to content

Commit

Permalink
MEDIA-2669: Fix SMB crash (#378)
Browse files Browse the repository at this point in the history
  • Loading branch information
RicardoMDomingues authored Jun 27, 2024
1 parent f58c712 commit 536df3a
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 31 deletions.
5 changes: 5 additions & 0 deletions bridge/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ Bridge::Bridge(const config::Config& config)

Bridge::~Bridge()
{
if (_httpd)
{
_httpd = nullptr;
}

if (_mixerManager)
{
_mixerManager->stop();
Expand Down
2 changes: 1 addition & 1 deletion test/integration/emulator/FakeTcpEndpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ void FakeTcpEndpoint::stop(Endpoint::IStopEvents* listener)
_receiveJobs.post([this, listener]() { listener->onEndpointStopped(this); });
}
}
else if (_state == State::CREATED)
else if (_state == State::CREATED || _state == State::STOPPING)
{
if (listener)
{
Expand Down
2 changes: 2 additions & 0 deletions test/transport/FakeNetwork.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ void Firewall::removePortMapping(Protocol protocol, transport::SocketAddress& la

void Firewall::block(const transport::SocketAddress& source, const transport::SocketAddress& destination)
{
std::lock_guard<std::mutex> lock(_nodesMutex);
if (isBlackListed(source, destination))
{
return;
Expand All @@ -436,6 +437,7 @@ void Firewall::block(const transport::SocketAddress& source, const transport::So

void Firewall::unblock(const transport::SocketAddress& source, const transport::SocketAddress& destination)
{
std::lock_guard<std::mutex> lock(_nodesMutex);
auto it = _blackList.find(std::pair<transport::SocketAddress, transport::SocketAddress>(source, destination));
if (it != _blackList.end())
{
Expand Down
9 changes: 6 additions & 3 deletions transport/TcpEndpointImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ RtpDepacketizer::RtpDepacketizer(int socketHandle, memory::PacketPoolAllocator&
: fd(socketHandle),
_receivedBytes(0),
_allocator(allocator),
_streamPrestine(true),
_streamPristine(true),
_remoteDisconnect(false)
{
}
Expand Down Expand Up @@ -75,7 +75,7 @@ memory::UniquePacket RtpDepacketizer::receive()
if (_header.get() >= memory::Packet::size)
{
// attack with malicious length specifier
_streamPrestine = false;
_streamPristine = false;
return memory::makeUniquePacket(_allocator);
}
}
Expand Down Expand Up @@ -355,6 +355,9 @@ void TcpEndpointImpl::stop(Endpoint::IStopEvents* listener)
}
else
{
// We do support stop(Endpoint::IStopEvents*) to be called multiple times
// but only one at a time. So we can't register multiple listeners in waiting state
assert(_stopListener == nullptr);
_stopListener = listener;
}
});
Expand All @@ -369,7 +372,7 @@ void TcpEndpointImpl::onSocketShutdown(int fd)
if (_depacketizer.fd == fd)
{
auto state = _state.load();
// Avoid race conditions whith endpoint stop(Endpoint::IStopEvents* listener)
// Avoid race conditions with endpoint stop(Endpoint::IStopEvents* listener)
// called when the TcpSocket is being deleted by SMB at the same time we receive
// a close request from remote side
while (state == State::CONNECTING || state == State::CONNECTED)
Expand Down
4 changes: 2 additions & 2 deletions transport/TcpEndpointImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class RtpDepacketizer

memory::UniquePacket receive();

bool isGood() const { return fd != -1 && _streamPrestine && !_remoteDisconnect; }
bool isGood() const { return fd != -1 && _streamPristine && !_remoteDisconnect; }
bool hasRemoteDisconnected() const { return _remoteDisconnect; }
void close();

Expand All @@ -27,7 +27,7 @@ class RtpDepacketizer
size_t _receivedBytes;
memory::UniquePacket _incompletePacket;
memory::PacketPoolAllocator& _allocator;
bool _streamPrestine;
bool _streamPristine;
bool _remoteDisconnect;
};

Expand Down
14 changes: 12 additions & 2 deletions transport/TransportFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,8 +525,16 @@ class TransportFactoryImpl final : public TransportFactory,
}

logger::debug("closing %s", _name, endpoint->getName());
++_pendingTasks;
endpoint->stop(this);
if (endpoint->getState() == Endpoint::State::CREATED)
{
// When transport is CREATED we can delete it right now without call stop
enqueueDeleteJobNow(endpoint);
}
else
{
++_pendingTasks;
endpoint->stop(this);
}
}

void shutdownEndpoint(ServerEndpoint* endpoint)
Expand Down Expand Up @@ -608,6 +616,8 @@ class TransportFactoryImpl final : public TransportFactory,
--_pendingTasks; // epoll stop is complete
}

void enqueueDeleteJobNow(Endpoint* endpoint) { _garbageQueue.addJob<DeleteJob<Endpoint>>(endpoint, _pendingTasks); }

jobmanager::JobManager& _jobManager;
SrtpClientFactory& _srtpClientFactory;
const config::Config& _config;
Expand Down
63 changes: 41 additions & 22 deletions transport/TransportImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
#else
#define RTP_LOG(fmt, ...)
#endif

namespace transport
{
constexpr uint32_t Mbps100 = 100000;
Expand Down Expand Up @@ -1100,10 +1099,15 @@ void TransportImpl::internalIceReceived(Endpoint& endpoint,
}
else if (_rtpIceSession && endpoint.getTransportType() == ice::TransportType::TCP)
{
_rtpIceSession->onStunPacketReceived(&endpoint, source, packet->get(), packet->getLength(), timestamp);
if (dataReceiver)
// It is possible that a ice packet has been enqueued while we were discarding TCP candidates.
// Check if it CONNECTED to avoid IceSession to store a reference that will be dangling soon
if (Endpoint::State::CONNECTED == endpoint.getState())
{
dataReceiver->onIceReceived(this, timestamp);
_rtpIceSession->onStunPacketReceived(&endpoint, source, packet->get(), packet->getLength(), timestamp);
if (dataReceiver)
{
dataReceiver->onIceReceived(this, timestamp);
}
}
}
}
Expand All @@ -1113,6 +1117,26 @@ void TransportImpl::onTcpDisconnect(Endpoint& endpoint)
_jobQueue.post([&]() { _rtpIceSession->onTcpDisconnect(&endpoint); });
}

void TransportImpl::onEndpointStopped(Endpoint* endpoint)
{
_jobQueue.post(utils::bind(&TransportImpl::onTcpEndpointStoppedInternal, this, endpoint));
}

void TransportImpl::onTcpEndpointStoppedInternal(Endpoint* endpoint)
{
auto it = std::find_if(_rtpEndpoints.begin(), _rtpEndpoints.end(), [endpoint](const auto& ep) {
return static_cast<const transport::Endpoint*>(ep.get()) == endpoint;
});

if (it != _rtpEndpoints.end())
{
std::iter_swap(it, std::prev(_rtpEndpoints.end()));
_rtpEndpoints.pop_back();
}

--_jobCounter;
}

namespace
{
uint32_t processReportBlocks(const uint32_t count,
Expand Down Expand Up @@ -2039,25 +2063,20 @@ void TransportImpl::onIceStateChanged(ice::IceSession* session, const ice::IceSe

if (_rtpIceSession->getRole() == ice::IceRole::CONTROLLING)
{
while (!_rtpEndpoints.empty() && _rtpEndpoints.back().get() != _selectedRtp &&
_rtpEndpoints.back()->getTransportType() == ice::TransportType::TCP)
{
logger::info("discarding %s, ref %ld",
_loggableId.c_str(),
_rtpEndpoints.back()->getName(),
_rtpEndpoints.back().use_count());
_rtpIceSession->onTcpRemoved(_rtpEndpoints.back().get());
_rtpEndpoints.back()->unregisterListener(this);

_rtpEndpoints.pop_back();
}
while (!_rtpEndpoints.empty() && _rtpEndpoints.front().get() != _selectedRtp &&
_rtpEndpoints.front()->getTransportType() == ice::TransportType::TCP)
for (auto& endpoint : _rtpEndpoints)
{
logger::info("discarding %s", _loggableId.c_str(), _rtpEndpoints.front()->getName());
_rtpIceSession->onTcpRemoved(_rtpEndpoints.front().get());
_rtpEndpoints.front()->unregisterListener(this);
_rtpEndpoints.erase(_rtpEndpoints.begin());
if (endpoint.get() != _selectedRtp && endpoint->getTransportType() == ice::TransportType::TCP)
{
logger::info("discarding %s, ref %ld",
_loggableId.c_str(),
endpoint->getName(),
endpoint.use_count());

_rtpIceSession->onTcpRemoved(endpoint.get());
endpoint->unregisterListener(this);
++_jobCounter;
endpoint->stop(this);
}
}
}

Expand Down
5 changes: 4 additions & 1 deletion transport/TransportImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ class TransportImpl : public RtcTransport,
public sctp::SctpServerPort::IEvents,
public sctp::SctpAssociation::IEvents,
public ServerEndpoint::IEvents,
private RtcpReportProducer::RtcpSender
private RtcpReportProducer::RtcpSender,
private Endpoint::IStopEvents
{
public:
TransportImpl(jobmanager::JobManager& jobmanager,
Expand Down Expand Up @@ -238,6 +239,8 @@ class TransportImpl : public RtcTransport,
uint64_t timestamp) override;

void onTcpDisconnect(Endpoint& endpoint) override;
void onEndpointStopped(Endpoint* endpoint) override;
void onTcpEndpointStoppedInternal(Endpoint* endpoint);

void onIceTcpConnect(std::shared_ptr<Endpoint> endpoint,
const SocketAddress& source,
Expand Down

0 comments on commit 536df3a

Please sign in to comment.