diff --git a/modules/kinesis/src/main/scala/com/snowplowanalytics/snowplow/sources/kinesis/KinesisSource.scala b/modules/kinesis/src/main/scala/com/snowplowanalytics/snowplow/sources/kinesis/KinesisSource.scala index f3dbf0c..988b1dd 100644 --- a/modules/kinesis/src/main/scala/com/snowplowanalytics/snowplow/sources/kinesis/KinesisSource.scala +++ b/modules/kinesis/src/main/scala/com/snowplowanalytics/snowplow/sources/kinesis/KinesisSource.scala @@ -8,6 +8,7 @@ package com.snowplowanalytics.snowplow.sources.kinesis import cats.effect.{Async, Ref, Sync} +import cats.data.NonEmptyList import cats.implicits._ import com.snowplowanalytics.snowplow.sources.SourceAndAck import com.snowplowanalytics.snowplow.sources.internal.{LowLevelEvents, LowLevelSource} @@ -48,32 +49,68 @@ object KinesisSource { val actionQueue = new SynchronousQueue[KCLAction]() for { _ <- Stream.resource(KCLScheduler.populateQueue[F](config, actionQueue)) - events <- Stream.emit(pullFromQueue(actionQueue, liveness).stream).repeat + events <- Stream.emit(pullFromQueueAndEmit(actionQueue, liveness).stream).repeat } yield events } - private def pullFromQueue[F[_]: Sync]( + private def pullFromQueueAndEmit[F[_]: Sync]( queue: SynchronousQueue[KCLAction], liveness: Ref[F, FiniteDuration] ): Pull[F, LowLevelEvents[Map[String, Checkpointable]], Unit] = - Pull.eval(resolveNextAction(queue, liveness)).flatMap { - case KCLAction.ProcessRecords(_, processRecordsInput) if processRecordsInput.records.asScala.isEmpty => - pullFromQueue[F](queue, liveness) - case KCLAction.ProcessRecords(shardId, processRecordsInput) => - Pull.output1(provideNextChunk(shardId, processRecordsInput)).covary[F] *> pullFromQueue[F](queue, liveness) - case KCLAction.ShardEnd(shardId, await, shardEndedInput) => - handleShardEnd[F](shardId, await, shardEndedInput) *> Pull.done - case KCLAction.KCLError(t) => - Pull.eval(Logger[F].error(t)("Exception from Kinesis source")) *> Pull.raiseError[F](t) + Pull.eval(pullFromQueue(queue, liveness)).flatMap { case PullFromQueueResult(actions, hasShardEnd) => + val toEmit = actions.traverse { + case KCLAction.ProcessRecords(_, processRecordsInput) if processRecordsInput.records.asScala.isEmpty => + Pull.done + case KCLAction.ProcessRecords(shardId, processRecordsInput) => + Pull.output1(provideNextChunk(shardId, processRecordsInput)).covary[F] + case KCLAction.ShardEnd(shardId, await, shardEndedInput) => + handleShardEnd[F](shardId, await, shardEndedInput) + case KCLAction.KCLError(t) => + Pull.eval(Logger[F].error(t)("Exception from Kinesis source")) *> Pull.raiseError[F](t) + } + if (hasShardEnd) { + val log = Logger[F].info { + actions + .collect { case KCLAction.ShardEnd(shardId, _, _) => + shardId + } + .mkString("Ending this window of events early because reached the end of Kinesis shards: ", ",", "") + } + Pull.eval(log).covaryOutput *> toEmit *> Pull.done + } else + toEmit *> pullFromQueueAndEmit(queue, liveness) } - private def resolveNextAction[F[_]: Sync](queue: SynchronousQueue[KCLAction], liveness: Ref[F, FiniteDuration]): F[KCLAction] = { - val nextAction = Sync[F].delay(Option[KCLAction](queue.poll)).flatMap { + private case class PullFromQueueResult(actions: NonEmptyList[KCLAction], hasShardEnd: Boolean) + + private def pullFromQueue[F[_]: Sync](queue: SynchronousQueue[KCLAction], liveness: Ref[F, FiniteDuration]): F[PullFromQueueResult] = + resolveNextAction(queue) + .productL(updateLiveness(liveness)) + .flatMap { + case shardEnd: KCLAction.ShardEnd => + // If we reached the end of one shard, it is likely we reached the end of other shards too. + // Therefore pull more actions from the queue, to minimize the number of times we need to do + // an early close of the inner stream. + resolveAllActions(queue).map { more => + PullFromQueueResult(NonEmptyList(shardEnd, more), hasShardEnd = true) + } + case other => + PullFromQueueResult(NonEmptyList.one(other), hasShardEnd = false).pure[F] + } + + /** Always returns a `KCLAction`, possibly waiting until one is available */ + private def resolveNextAction[F[_]: Sync](queue: SynchronousQueue[KCLAction]): F[KCLAction] = + Sync[F].delay(Option[KCLAction](queue.poll)).flatMap { case Some(action) => Sync[F].pure(action) case None => Sync[F].interruptible(queue.take) } - nextAction <* updateLiveness(liveness) - } + + /** Returns immediately, but the `List[KCLAction]` might be empty */ + private def resolveAllActions[F[_]: Sync](queue: SynchronousQueue[KCLAction]): F[List[KCLAction]] = + for { + ret <- Sync[F].delay(new java.util.ArrayList[KCLAction]()) + _ <- Sync[F].delay(queue.drainTo(ret)) + } yield ret.asScala.toList private def updateLiveness[F[_]: Sync](liveness: Ref[F, FiniteDuration]): F[Unit] = Sync[F].realTime.flatMap(now => liveness.set(now)) @@ -89,17 +126,14 @@ object KinesisSource { LowLevelEvents(chunk, Map[String, Checkpointable](shardId -> checkpointable), Some(firstRecord.approximateArrivalTimestamp)) } - private def handleShardEnd[F[_]: Sync]( + private def handleShardEnd[F[_]]( shardId: String, await: CountDownLatch, shardEndedInput: ShardEndedInput - ) = { + ): Pull[F, LowLevelEvents[Map[String, Checkpointable]], Unit] = { val checkpointable = Checkpointable.ShardEnd(shardEndedInput.checkpointer, await) val last = LowLevelEvents(Chunk.empty, Map[String, Checkpointable](shardId -> checkpointable), None) - Pull - .eval(Logger[F].info(s"Ending this window of events early because reached the end of Kinesis shard $shardId")) - .covaryOutput *> - Pull.output1(last).covary[F] + Pull.output1(last) } }