diff --git a/README.md b/README.md index 0e1ba04..30a0f2c 100644 --- a/README.md +++ b/README.md @@ -152,16 +152,17 @@ It requires: You can also use `Sharding.startProxy` if you need to send messages to entities located on `other` nodes. -To send a message to a sharded entity, use `send`. To stop one, use `stop`. +To send a message to a sharded entity without expecting a response, use `send`. To send a message to a sharded entity expecting a response, use `ask`. To stop one, use `stop`. The `entityId` identifies the entity to target. Messages sent to the same `entityId` from different nodes in the cluster will be handled by the same actor. ```scala def send(entityId: String, data: M): Task[Unit] +def ask[R](entityId: String, data: M): Task[R] def stop(entityId: String): Task[Unit] ``` **Note on Serialization** -Akka messages are serialized when they are sent across the network. By default, Java serialization is used but it is not recommended to use it in production. +Akka messages are serialized when they are sent across the network. By default, Java serialization is used, but it is not recommended in production. See [Akka Documentation](https://doc.akka.io/docs/akka/current/serialization.html) to see how to provide your own serializer. This library wraps messages inside of a `zio.akka.cluster.sharding.MessageEnvelope` case class, so your serializer needs to cover it as well. diff --git a/src/main/scala/zio/akka/cluster/sharding/Entity.scala b/src/main/scala/zio/akka/cluster/sharding/Entity.scala index f6544da..540c3e5 100644 --- a/src/main/scala/zio/akka/cluster/sharding/Entity.scala +++ b/src/main/scala/zio/akka/cluster/sharding/Entity.scala @@ -1,8 +1,9 @@ package zio.akka.cluster.sharding -import zio.{ Ref, UIO } +import zio.{ Ref, Task, UIO } trait Entity[State] { + def replyToSender[R](msg: R): Task[Unit] def id: String def state: Ref[Option[State]] def stop: UIO[Unit] diff --git a/src/main/scala/zio/akka/cluster/sharding/Sharding.scala b/src/main/scala/zio/akka/cluster/sharding/Sharding.scala index 4afcd37..64cc2c0 100644 --- a/src/main/scala/zio/akka/cluster/sharding/Sharding.scala +++ b/src/main/scala/zio/akka/cluster/sharding/Sharding.scala @@ -1,10 +1,15 @@ package zio.akka.cluster.sharding import akka.actor.{ Actor, ActorContext, ActorRef, ActorSystem, PoisonPill, Props } +import akka.pattern.{ ask => askPattern } import akka.cluster.sharding.{ ClusterSharding, ClusterShardingSettings } +import akka.util.Timeout import zio.akka.cluster.sharding import zio.akka.cluster.sharding.MessageEnvelope.{ MessagePayload, PoisonPillPayload } -import zio.{ Has, Ref, Runtime, Task, UIO, ZIO } +import zio.{ =!=, Has, Ref, Runtime, Task, UIO, ZIO } + +import scala.concurrent.duration._ +import scala.reflect.ClassTag /** * A `Sharding[M]` is able to send messages of type `M` to a sharded entity or to stop one. @@ -15,6 +20,8 @@ trait Sharding[M] { def stop(entityId: String): Task[Unit] + def ask[R](entityId: String, data: M)(implicit tag: ClassTag[R], proof: R =!= Nothing): Task[R] + } object Sharding { @@ -25,12 +32,14 @@ object Sharding { * @param name the name of the entity type * @param onMessage the behavior of the entity when it receives a message * @param numberOfShards a fixed number of shards + * @param askTimeout a finite duration specifying how long an ask is allowed to wait for an entity to respond * @return a [[Sharding]] object that can be used to send messages to sharded entities */ def start[Msg, State]( name: String, onMessage: Msg => ZIO[Entity[State], Nothing, Unit], - numberOfShards: Int = 100 + numberOfShards: Int = 100, + askTimeout: FiniteDuration = 10.seconds ): ZIO[Has[ActorSystem], Throwable, Sharding[Msg]] = for { rts <- ZIO.runtime[Has[ActorSystem]] @@ -54,20 +63,23 @@ object Sharding { ) } yield new ShardingImpl[Msg] { override val getShardingRegion: ActorRef = shardingRegion + override implicit val timeout: Timeout = Timeout(askTimeout) } /** - * Starts cluster sharding in proxy mode for a given entity type. + * Starts cluster sharding in proxy mode for a given entity type. * - * @param name the name of the entity type - * @param role an optional role to specify that this entity type is located on cluster nodes with a specific role + * @param name the name of the entity type + * @param role an optional role to specify that this entity type is located on cluster nodes with a specific role * @param numberOfShards a fixed number of shards + * @param askTimeout a finite duration specifying how long an ask is allowed to wait for an entity to respond * @return a [[Sharding]] object that can be used to send messages to sharded entities on other nodes */ def startProxy[Msg]( name: String, role: Option[String], - numberOfShards: Int = 100 + numberOfShards: Int = 100, + askTimeout: FiniteDuration = 10.seconds ): ZIO[Has[ActorSystem], Throwable, Sharding[Msg]] = for { rts <- ZIO.runtime[Has[ActorSystem]] @@ -89,11 +101,12 @@ object Sharding { ) ) } yield new ShardingImpl[Msg] { + override val timeout: Timeout = Timeout(askTimeout) override val getShardingRegion: ActorRef = shardingRegion } private[sharding] trait ShardingImpl[Msg] extends Sharding[Msg] { - + implicit val timeout: Timeout val getShardingRegion: ActorRef override def send(entityId: String, data: Msg): Task[Unit] = @@ -101,6 +114,12 @@ object Sharding { override def stop(entityId: String): Task[Unit] = Task(getShardingRegion ! sharding.MessageEnvelope(entityId, PoisonPillPayload)) + + override def ask[R](entityId: String, data: Msg)(implicit tag: ClassTag[R], proof: R =!= Nothing): Task[R] = + Task.fromFuture(_ => + (getShardingRegion ? sharding.MessageEnvelope(entityId, MessagePayload(data))) + .mapTo[R] + ) } private[sharding] class ShardEntity[Msg, State](rts: Runtime[Any])( @@ -110,12 +129,13 @@ object Sharding { val ref: Ref[Option[State]] = rts.unsafeRun(Ref.make[Option[State]](None)) val actorContext: ActorContext = context val entity: Entity[State] = new Entity[State] { - override def id: String = context.self.path.name - override def state: Ref[Option[State]] = ref - override def stop: UIO[Unit] = UIO(actorContext.stop(self)) + override def id: String = context.self.path.name + override def state: Ref[Option[State]] = ref + override def stop: UIO[Unit] = UIO(actorContext.stop(self)) + override def replyToSender[R](msg: R): Task[Unit] = Task(context.sender() ! msg) } - def receive: PartialFunction[Any, Unit] = { + def receive: Receive = { case MessagePayload(msg) => rts.unsafeRunSync(onMessage(msg.asInstanceOf[Msg]).provide(entity)) () diff --git a/src/test/scala/zio/akka/cluster/sharding/ShardingSpec.scala b/src/test/scala/zio/akka/cluster/sharding/ShardingSpec.scala index 495a48f..4942ae6 100644 --- a/src/test/scala/zio/akka/cluster/sharding/ShardingSpec.scala +++ b/src/test/scala/zio/akka/cluster/sharding/ShardingSpec.scala @@ -75,6 +75,16 @@ object ShardingSpec extends DefaultRunnableSpec { } yield res )(equalTo(msg)).provideLayer(actorSystem) }, + testM("send and receive a message using ask") { + val onMessage: String => ZIO[Entity[Any], Nothing, Unit] = + incomingMsg => ZIO.accessM[Entity[Any]](r => r.replyToSender(incomingMsg).orDie) + assertM( + for { + sharding <- Sharding.start(shardName, onMessage) + reply <- sharding.ask[String](shardId, msg) + } yield reply + )(equalTo(msg)).provideLayer(actorSystem) + }, testM("gather state") { assertM( for {