Skip to content

Commit

Permalink
[amendment] explicitly disallow >1 record processors for the same shard
Browse files Browse the repository at this point in the history
  • Loading branch information
istreeter committed Sep 9, 2024
1 parent e983e5a commit 3368625
Showing 1 changed file with 24 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber
import java.net.URI
import java.util.Date
import java.util.concurrent.{CountDownLatch, SynchronousQueue}
import java.util.concurrent.atomic.AtomicReference
import scala.concurrent.duration.FiniteDuration
import scala.jdk.CollectionConverters._

Expand Down Expand Up @@ -185,30 +186,47 @@ object KinesisSource {
private def runRecordProcessor[F[_]: Async](scheduler: Scheduler): Resource[F, Unit] =
Sync[F].blocking(scheduler.run()).background *> Resource.onFinalize(Sync[F].blocking(scheduler.shutdown()))

private def shardRecordProcessor(queue: SynchronousQueue[KCLAction]): ShardRecordProcessor = new ShardRecordProcessor {
private def shardRecordProcessor(
queue: SynchronousQueue[KCLAction],
currentShardIds: AtomicReference[Set[String]]
): ShardRecordProcessor = new ShardRecordProcessor {
private var shardId: String = _

def initialize(initializationInput: InitializationInput): Unit =
def initialize(initializationInput: InitializationInput): Unit = {
shardId = initializationInput.shardId
val oldSet = currentShardIds.getAndUpdate(_ + shardId)
if (oldSet.contains(shardId)) {
val action = KCLError(new RuntimeException(s"Refusing to initialize a duplicate record processor for shard $shardId"))
queue.put(action)
}
}

def shardEnded(shardEndedInput: ShardEndedInput): Unit = {
val countDownLatch = new CountDownLatch(1)
queue.put(ShardEnd(shardId, countDownLatch, shardEndedInput))
countDownLatch.await()
currentShardIds.updateAndGet(_ - shardId)
()
}

def processRecords(processRecordsInput: ProcessRecordsInput): Unit = {
val action = ProcessRecords(shardId, processRecordsInput)
queue.put(action)
}

def leaseLost(leaseLostInput: LeaseLostInput): Unit = ()
def leaseLost(leaseLostInput: LeaseLostInput): Unit = {
currentShardIds.updateAndGet(_ - shardId)
()
}

def shutdownRequested(shutdownRequestedInput: ShutdownRequestedInput): Unit = ()
}

private def recordProcessorFactory(queue: SynchronousQueue[KCLAction]): ShardRecordProcessorFactory = { () =>
shardRecordProcessor(queue)
private def recordProcessorFactory(
queue: SynchronousQueue[KCLAction],
currentShardIds: AtomicReference[Set[String]]
): ShardRecordProcessorFactory = { () =>
shardRecordProcessor(queue, currentShardIds)
}

private def initialPositionOf(config: KinesisSourceConfig.InitialPosition): InitialPositionInStreamExtended =
Expand Down Expand Up @@ -236,7 +254,7 @@ object KinesisSource {
dynamoDbClient,
cloudWatchClient,
kinesisConfig.workerIdentifier,
recordProcessorFactory(queue)
recordProcessorFactory(queue, new AtomicReference(Set.empty[String]))
)

val retrievalConfig =
Expand Down

0 comments on commit 3368625

Please sign in to comment.