From 336862598fc1c4daec9f95a58cd77deabf9bc1be Mon Sep 17 00:00:00 2001 From: Ian Streeter Date: Mon, 9 Sep 2024 14:34:56 +0100 Subject: [PATCH] [amendment] explicitly disallow >1 record processors for the same shard --- .../sources/kinesis/KinesisSource.scala | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/modules/kinesis/src/main/scala/com/snowplowanalytics/snowplow/sources/kinesis/KinesisSource.scala b/modules/kinesis/src/main/scala/com/snowplowanalytics/snowplow/sources/kinesis/KinesisSource.scala index d470a7e..63b29aa 100644 --- a/modules/kinesis/src/main/scala/com/snowplowanalytics/snowplow/sources/kinesis/KinesisSource.scala +++ b/modules/kinesis/src/main/scala/com/snowplowanalytics/snowplow/sources/kinesis/KinesisSource.scala @@ -42,6 +42,7 @@ import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber import java.net.URI import java.util.Date import java.util.concurrent.{CountDownLatch, SynchronousQueue} +import java.util.concurrent.atomic.AtomicReference import scala.concurrent.duration.FiniteDuration import scala.jdk.CollectionConverters._ @@ -185,16 +186,27 @@ object KinesisSource { private def runRecordProcessor[F[_]: Async](scheduler: Scheduler): Resource[F, Unit] = Sync[F].blocking(scheduler.run()).background *> Resource.onFinalize(Sync[F].blocking(scheduler.shutdown())) - private def shardRecordProcessor(queue: SynchronousQueue[KCLAction]): ShardRecordProcessor = new ShardRecordProcessor { + private def shardRecordProcessor( + queue: SynchronousQueue[KCLAction], + currentShardIds: AtomicReference[Set[String]] + ): ShardRecordProcessor = new ShardRecordProcessor { private var shardId: String = _ - def initialize(initializationInput: InitializationInput): Unit = + def initialize(initializationInput: InitializationInput): Unit = { shardId = initializationInput.shardId + val oldSet = currentShardIds.getAndUpdate(_ + shardId) + if (oldSet.contains(shardId)) { + val action = KCLError(new RuntimeException(s"Refusing to initialize a duplicate record processor for shard $shardId")) + queue.put(action) + } + } def shardEnded(shardEndedInput: ShardEndedInput): Unit = { val countDownLatch = new CountDownLatch(1) queue.put(ShardEnd(shardId, countDownLatch, shardEndedInput)) countDownLatch.await() + currentShardIds.updateAndGet(_ - shardId) + () } def processRecords(processRecordsInput: ProcessRecordsInput): Unit = { @@ -202,13 +214,19 @@ object KinesisSource { queue.put(action) } - def leaseLost(leaseLostInput: LeaseLostInput): Unit = () + def leaseLost(leaseLostInput: LeaseLostInput): Unit = { + currentShardIds.updateAndGet(_ - shardId) + () + } def shutdownRequested(shutdownRequestedInput: ShutdownRequestedInput): Unit = () } - private def recordProcessorFactory(queue: SynchronousQueue[KCLAction]): ShardRecordProcessorFactory = { () => - shardRecordProcessor(queue) + private def recordProcessorFactory( + queue: SynchronousQueue[KCLAction], + currentShardIds: AtomicReference[Set[String]] + ): ShardRecordProcessorFactory = { () => + shardRecordProcessor(queue, currentShardIds) } private def initialPositionOf(config: KinesisSourceConfig.InitialPosition): InitialPositionInStreamExtended = @@ -236,7 +254,7 @@ object KinesisSource { dynamoDbClient, cloudWatchClient, kinesisConfig.workerIdentifier, - recordProcessorFactory(queue) + recordProcessorFactory(queue, new AtomicReference(Set.empty[String])) ) val retrievalConfig =