Skip to content

Commit

Permalink
Merge pull request #4133 from armanbilge/issue/4009
Browse files Browse the repository at this point in the history
Add `ownPoller`
  • Loading branch information
armanbilge authored Oct 29, 2024
2 parents 72fc141 + cef80fe commit ef28cf9
Show file tree
Hide file tree
Showing 12 changed files with 137 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,8 @@ abstract class PollingSystem {

/**
* Creates a new instance of the user-facing interface.
*
* @param access
* callback to obtain a thread-local `Poller`.
* @return
* an instance of the user-facing interface `Api`.
*/
def makeApi(access: (Poller => Unit) => Unit): Api
def makeApi(ctx: PollingContext[Poller]): Api

/**
* Creates a new instance of the thread-local data structure used for polling.
Expand Down Expand Up @@ -109,7 +104,24 @@ abstract class PollingSystem {

}

private object PollingSystem {
sealed trait PollingContext[P] {

/**
* Register a callback to obtain a thread-local `Poller`
*/
def accessPoller(cb: P => Unit): Unit

/**
* Returns `true` if it is safe to interact with this `Poller`. Implementors of this method
* may be best-effort: it is always safe to return `false`, so callers must have an adequate
* fallback for the non-owning case.
*/
def ownPoller(poller: P): Boolean
}

private[unsafe] trait UnsealedPollingContext[P] extends PollingContext[P]

object PollingSystem {

/**
* Type alias for a `PollingSystem` that has a specified `Poller` type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ private[unsafe] abstract class IORuntimeCompanionPlatform { this: IORuntime.type

(
threadPool,
pollingSystem.makeApi(threadPool.accessPoller),
pollingSystem.makeApi(threadPool),
{ () =>
unregisterMBeans()
threadPool.shutdown()
Expand Down
119 changes: 77 additions & 42 deletions core/jvm/src/main/scala/cats/effect/unsafe/SelectorSystem.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import scala.util.control.NonFatal

import java.nio.channels.SelectableChannel
import java.nio.channels.spi.{AbstractSelector, SelectorProvider}
import java.util.Iterator

import SelectorSystem._

Expand All @@ -30,8 +31,8 @@ final class SelectorSystem private (provider: SelectorProvider) extends PollingS

def close(): Unit = ()

def makeApi(access: (Poller => Unit) => Unit): Selector =
new SelectorImpl(access, provider)
def makeApi(ctx: PollingContext[Poller]): Selector =
new SelectorImpl(ctx, provider)

def makePoller(): Poller = new Poller(provider.openSelector())

Expand Down Expand Up @@ -68,29 +69,21 @@ final class SelectorSystem private (provider: SelectorProvider) extends PollingS

val value = if (error ne null) Left(error) else Right(readyOps)

var head: CallbackNode = null
var prev: CallbackNode = null
var node = key.attachment().asInstanceOf[CallbackNode]
while (node ne null) {
val next = node.next
val callbacks = key.attachment().asInstanceOf[Callbacks]
val iter = callbacks.iterator()
while (iter.hasNext()) {
val node = iter.next()

if ((node.interest & readyOps) != 0) { // execute callback and drop this node
if ((node.interest & readyOps) != 0) { // drop this node and execute callback
node.remove()
val cb = node.callback
if (cb != null) {
cb(value)
polled = true
}
if (prev ne null) prev.next = next
} else { // keep this node
prev = node
if (head eq null)
head = node
}

node = next
}

key.attach(head) // if key was canceled this will null attachment
()
}

Expand All @@ -107,42 +100,38 @@ final class SelectorSystem private (provider: SelectorProvider) extends PollingS
}

final class SelectorImpl private[SelectorSystem] (
access: (Poller => Unit) => Unit,
ctx: PollingContext[Poller],
val provider: SelectorProvider
) extends Selector {

def select(ch: SelectableChannel, ops: Int): IO[Int] = IO.async { selectCb =>
IO.async_[CallbackNode] { cb =>
access { data =>
IO.async_[Option[IO[Unit]]] { cb =>
ctx.accessPoller { poller =>
try {
val selector = data.selector
val selector = poller.selector
val key = ch.keyFor(selector)

val node = if (key eq null) { // not yet registered on this selector
val node = new CallbackNode(ops, selectCb, null)
ch.register(selector, ops, node)
node
val cbs = new Callbacks
ch.register(selector, ops, cbs)
cbs.append(ops, selectCb)
} else { // existing key
// mixin the new interest
key.interestOps(key.interestOps() | ops)
val node =
new CallbackNode(ops, selectCb, key.attachment().asInstanceOf[CallbackNode])
key.attach(node)
node
val cbs = key.attachment().asInstanceOf[Callbacks]
cbs.append(ops, selectCb)
}

val cancel = IO {
if (ctx.ownPoller(poller))
node.remove()
else
node.clear()
}

cb(Right(node))
cb(Right(Some(cancel)))
} catch { case ex if NonFatal(ex) => cb(Left(ex)) }
}
}.map { node =>
Some {
IO {
// set all interest bits
node.interest = -1
// clear for gc
node.callback = null
}
}
}
}

Expand All @@ -161,9 +150,55 @@ object SelectorSystem {

def apply(): SelectorSystem = apply(SelectorProvider.provider())

private final class CallbackNode(
var interest: Int,
var callback: Either[Throwable, Int] => Unit,
var next: CallbackNode
)
private final class Callbacks {

private var head: Node = null
private var last: Node = null

def append(interest: Int, callback: Either[Throwable, Int] => Unit): Node = {
val node = new Node(interest, callback)
if (last ne null) {
last.next = node
node.prev = last
} else {
head = node
}
last = node
node
}

def iterator(): Iterator[Node] = new Iterator[Node] {
private var _next = head

def hasNext() = _next ne null

def next() = {
val next = _next
_next = next.next
next
}
}

final class Node(
var interest: Int,
var callback: Either[Throwable, Int] => Unit
) {
var prev: Node = null
var next: Node = null

def remove(): Unit = {
if (prev ne null) prev.next = next
else head = next

if (next ne null) next.prev = prev
else last = prev
}

def clear(): Unit = {
interest = -1 // set all interest bits
callback = null // clear for gc
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ object SleepSystem extends PollingSystem {

def close(): Unit = ()

def makeApi(access: (Poller => Unit) => Unit): Api = this
def makeApi(ctx: PollingContext[Poller]): Api = this

def makePoller(): Poller = this

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ import WorkStealingThreadPool._
* contention. Work stealing is tried using a linear search starting from a random worker thread
* index.
*/
private[effect] final class WorkStealingThreadPool[P](
private[effect] final class WorkStealingThreadPool[P <: AnyRef](
threadCount: Int, // number of worker threads
private[unsafe] val threadPrefix: String, // prefix for the name of worker threads
private[unsafe] val blockerThreadPrefix: String, // prefix for the name of worker threads currently in a blocking region
Expand All @@ -71,7 +71,8 @@ private[effect] final class WorkStealingThreadPool[P](
system: PollingSystem.WithPoller[P],
reportFailure0: Throwable => Unit
) extends ExecutionContextExecutor
with Scheduler {
with Scheduler
with UnsealedPollingContext[P] {

import TracingConstants._
import WorkStealingThreadPoolConstants._
Expand All @@ -87,7 +88,7 @@ private[effect] final class WorkStealingThreadPool[P](
private[unsafe] val pollers: Array[P] =
new Array[AnyRef](threadCount).asInstanceOf[Array[P]]

private[unsafe] def accessPoller(cb: P => Unit): Unit = {
def accessPoller(cb: P => Unit): Unit = {

// figure out where we are
val thread = Thread.currentThread()
Expand All @@ -101,6 +102,14 @@ private[effect] final class WorkStealingThreadPool[P](
} else scheduleExternal(() => accessPoller(cb))
}

def ownPoller(poller: P): Boolean = {
val thread = Thread.currentThread()
if (thread.isInstanceOf[WorkerThread[_]]) {
val worker = thread.asInstanceOf[WorkerThread[P]]
worker.ownsPoller(poller)
} else false
}

/**
* Atomic variable for used for publishing changes to the references in the `workerThreads`
* array. Worker threads can be changed whenever blocking code is encountered on the pool.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import java.util.concurrent.atomic.AtomicBoolean
* system when compared to a fixed size thread pool whose worker threads all draw tasks from a
* single global work queue.
*/
private final class WorkerThread[P](
private final class WorkerThread[P <: AnyRef](
idx: Int,
// Local queue instance with exclusive write access.
private[this] var queue: LocalQueue,
Expand Down Expand Up @@ -291,6 +291,9 @@ private final class WorkerThread[P](
foreign.toMap
}

private[unsafe] def ownsPoller(poller: P): Boolean =
poller eq _poller

private[unsafe] def ownsTimers(timers: TimerHeap): Boolean =
sleepers eq timers

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ object EpollSystem extends PollingSystem {

def close(): Unit = ()

def makeApi(access: (Poller => Unit) => Unit): Api =
new FileDescriptorPollerImpl(access)
def makeApi(ctx: PollingContext[Poller]): Api =
new FileDescriptorPollerImpl(ctx)

def makePoller(): Poller = {
val fd = epoll_create1(0)
Expand All @@ -67,7 +67,7 @@ object EpollSystem extends PollingSystem {
def interrupt(targetThread: Thread, targetPoller: Poller): Unit = ()

private final class FileDescriptorPollerImpl private[EpollSystem] (
access: (Poller => Unit) => Unit)
ctx: PollingContext[Poller])
extends FileDescriptorPoller {

def registerFileDescriptor(
Expand All @@ -78,7 +78,7 @@ object EpollSystem extends PollingSystem {
Resource {
(Mutex[IO], Mutex[IO]).flatMapN { (readMutex, writeMutex) =>
IO.async_[(PollHandle, IO[Unit])] { cb =>
access { epoll =>
ctx.accessPoller { epoll =>
val handle = new PollHandle(readMutex, writeMutex)
epoll.register(fd, reads, writes, handle, cb)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@ private[unsafe] abstract class IORuntimeCompanionPlatform { this: IORuntime.type
): (ExecutionContext with Scheduler, system.Api, () => Unit) = {
val loop = new EventLoopExecutorScheduler[system.Poller](64, system)
val poller = loop.poller
(loop, system.makeApi(cb => cb(poller)), () => loop.shutdown())
val api = system.makeApi(
new UnsealedPollingContext[system.Poller] {
def accessPoller(cb: system.Poller => Unit) = cb(poller)
def ownPoller(poller: system.Poller) = true
}
)
(loop, api, () => loop.shutdown())
}

def createDefaultPollingSystem(): PollingSystem =
Expand Down
14 changes: 7 additions & 7 deletions core/native/src/main/scala/cats/effect/unsafe/KqueueSystem.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ object KqueueSystem extends PollingSystem {

def close(): Unit = ()

def makeApi(access: (Poller => Unit) => Unit): FileDescriptorPoller =
new FileDescriptorPollerImpl(access)
def makeApi(ctx: PollingContext[Poller]): FileDescriptorPoller =
new FileDescriptorPollerImpl(ctx)

def makePoller(): Poller = {
val fd = kqueue()
Expand All @@ -67,7 +67,7 @@ object KqueueSystem extends PollingSystem {
def interrupt(targetThread: Thread, targetPoller: Poller): Unit = ()

private final class FileDescriptorPollerImpl private[KqueueSystem] (
access: (Poller => Unit) => Unit
ctx: PollingContext[Poller]
) extends FileDescriptorPoller {
def registerFileDescriptor(
fd: Int,
Expand All @@ -76,7 +76,7 @@ object KqueueSystem extends PollingSystem {
): Resource[IO, FileDescriptorPollHandle] =
Resource.eval {
(Mutex[IO], Mutex[IO]).mapN {
new PollHandle(access, fd, _, _)
new PollHandle(ctx, fd, _, _)
}
}
}
Expand All @@ -86,7 +86,7 @@ object KqueueSystem extends PollingSystem {
(filter.toLong << 32) | ident.toLong

private final class PollHandle(
access: (Poller => Unit) => Unit,
ctx: PollingContext[Poller],
fd: Int,
readMutex: Mutex[IO],
writeMutex: Mutex[IO]
Expand All @@ -101,7 +101,7 @@ object KqueueSystem extends PollingSystem {
else
IO.async[Unit] { kqcb =>
IO.async_[Option[IO[Unit]]] { cb =>
access { kqueue =>
ctx.accessPoller { kqueue =>
kqueue.evSet(fd, EVFILT_READ, EV_ADD.toUShort, kqcb)
cb(Right(Some(IO(kqueue.removeCallback(fd, EVFILT_READ)))))
}
Expand All @@ -121,7 +121,7 @@ object KqueueSystem extends PollingSystem {
else
IO.async[Unit] { kqcb =>
IO.async_[Option[IO[Unit]]] { cb =>
access { kqueue =>
ctx.accessPoller { kqueue =>
kqueue.evSet(fd, EVFILT_WRITE, EV_ADD.toUShort, kqcb)
cb(Right(Some(IO(kqueue.removeCallback(fd, EVFILT_WRITE)))))
}
Expand Down
Loading

0 comments on commit ef28cf9

Please sign in to comment.