Skip to content

Commit

Permalink
Add Kinesis Sink
Browse files Browse the repository at this point in the history
  • Loading branch information
colmsnowplow committed Nov 14, 2023
1 parent 913a387 commit 66c40d3
Show file tree
Hide file tree
Showing 7 changed files with 323 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import org.specs2.mutable.SpecificationLike
import org.testcontainers.containers.localstack.LocalStackContainer

import software.amazon.awssdk.services.kinesis.KinesisAsyncClient
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain

import com.snowplowanalytics.snowplow.sources.EventProcessingConfig
import com.snowplowanalytics.snowplow.sources.EventProcessingConfig.NoWindowing
Expand All @@ -35,7 +36,7 @@ class KinesisSourceSpec
/** Resources which are shared across tests */
override val resource: Resource[IO, (LocalStackContainer, KinesisAsyncClient, String => KinesisSourceConfig)] =
for {
region <- Resource.eval(KinesisSourceConfig.getRuntimeRegion[IO])
region <- Resource.eval(IO.blocking((new DefaultAwsRegionProviderChain).getRegion))
localstack <- Localstack.resource(region, KINESIS_INITIALIZE_STREAMS)
kinesisClient <- Resource.eval(getKinesisClient(localstack.getEndpoint, region))
} yield (localstack, kinesisClient, getKinesisConfig(localstack.getEndpoint)(_))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
/*
* Copyright (c) 2023-present Snowplow Analytics Ltd. All rights reserved.
*
* This program is licensed to you under the Snowplow Community License Version 1.0,
* and you may not use this file except in compliance with the Snowplow Community License Version 1.0.
* You may obtain a copy of the Snowplow Community License Version 1.0 at https://docs.snowplow.io/community-license-1.0
*/
package com.snowplowanalytics.snowplow.sinks.kinesis

import cats.implicits._
import cats.{Applicative, Monoid, Parallel}
import cats.effect.{Async, Resource, Sync}
import cats.effect.kernel.Ref

import org.typelevel.log4cats.{Logger, SelfAwareStructuredLogger}
import org.typelevel.log4cats.slf4j.Slf4jLogger
import retry.syntax.all._
import retry.{RetryPolicies, RetryPolicy}

import software.amazon.awssdk.core.SdkBytes
import software.amazon.awssdk.services.kinesis.KinesisClient
import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.awscore.defaultsmode.DefaultsMode
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain
import software.amazon.awssdk.services.kinesis.model.{PutRecordsRequest, PutRecordsRequestEntry, PutRecordsResponse}

import java.net.URI
import java.util.UUID
import java.nio.charset.StandardCharsets.UTF_8

import scala.jdk.CollectionConverters._

import com.snowplowanalytics.snowplow.sinks.{Sink, Sinkable}

object KinesisSink {

def resource[F[_]: Parallel: Async](config: KinesisSinkConfig): Resource[F, Sink[F]] =
mkProducer[F](config).map { p =>
Sink(
writeToKinesis[F](
config.throttledBackoffPolicy,
RequestLimits(config.recordLimit, config.byteLimit),
p,
config.streamName,
_
)
)
}

private implicit def logger[F[_]: Sync]: SelfAwareStructuredLogger[F] = Slf4jLogger.getLogger[F]

private def buildKinesisClient(customEndpoint: Option[URI], region: Region): KinesisClient = {
val builder = KinesisClient.builder().region(region).defaultsMode(DefaultsMode.AUTO)
customEndpoint.foreach(e => builder.endpointOverride(e))
builder.build()
}

private def mkProducer[F[_]: Sync](config: KinesisSinkConfig): Resource[F, KinesisClient] = {
val make = Sync[F].delay(buildKinesisClient(config.customEndpoint, (new DefaultAwsRegionProviderChain).getRegion()))

Resource.make(make) { producer =>
Sync[F].blocking {
producer.close()
}
}
}

/**
* This function takes a list of records and splits it into several lists, where each list is as
* big as possible with respecting the record limit and the size limit.
*/
private[kinesis] def group[A](
records: List[A],
recordLimit: Int,
sizeLimit: Int,
getRecordSize: A => Int
): List[List[A]] = {
case class Batch(
size: Int,
count: Int,
records: List[A]
)

records
.foldLeft(List.empty[Batch]) { case (acc, record) =>
val recordSize = getRecordSize(record)
acc match {
case head :: tail =>
if (head.count + 1 > recordLimit || head.size + recordSize > sizeLimit)
List(Batch(recordSize, 1, List(record))) ++ List(head) ++ tail
else
List(Batch(head.size + recordSize, head.count + 1, record :: head.records)) ++ tail
case Nil =>
List(Batch(recordSize, 1, List(record)))
}
}
.map(_.records)
}

private def putRecords(
kinesis: KinesisClient,
streamName: String,
records: List[PutRecordsRequestEntry]
): PutRecordsResponse = {
val putRecordsRequest = {
val prr = PutRecordsRequest.builder()
prr
.streamName(streamName)
.records(records.asJava)
prr.build()
}
kinesis.putRecords(putRecordsRequest)
}

private def toKinesisRecords(records: List[Sinkable]): List[PutRecordsRequestEntry] =
records.map { r =>
val data = SdkBytes.fromByteArrayUnsafe(r.bytes)
val prre = PutRecordsRequestEntry
.builder()
.partitionKey(r.partitionKey.getOrElse(UUID.randomUUID.toString()))
.data(data)
.build()
prre
}

/**
* The result of trying to write a batch to kinesis
* @param nextBatchAttempt
* Records to re-package into another batch, either because of throttling or an internal error
* @param hadNonThrottleErrors
* Whether at least one of failures is not because of throttling
* @param exampleInternalError
* A message to help with logging
*/
private case class TryBatchResult(
nextBatchAttempt: Vector[PutRecordsRequestEntry],
hadNonThrottleErrors: Boolean,
exampleInternalError: Option[String]
)

private object TryBatchResult {

implicit private def tryBatchResultMonoid: Monoid[TryBatchResult] =
new Monoid[TryBatchResult] {
override val empty: TryBatchResult = TryBatchResult(Vector.empty, false, None)
override def combine(x: TryBatchResult, y: TryBatchResult): TryBatchResult =
TryBatchResult(
x.nextBatchAttempt ++ y.nextBatchAttempt,
x.hadNonThrottleErrors || y.hadNonThrottleErrors,
x.exampleInternalError.orElse(y.exampleInternalError)
)
}

/**
* The build method creates a TryBatchResult, which:
*
* - Returns an empty list and false for hadNonThrottleErrors if everything was successful
* - Returns the list of failed requests and true for hadNonThrottleErrors if we encountered
* any errors that weren't throttles
* - Returns the list of failed requests and false for hadNonThrottleErrors if we encountered
* only throttling
*/
def build(records: List[PutRecordsRequestEntry], prr: PutRecordsResponse): TryBatchResult =
if (prr.failedRecordCount().toInt =!= 0)
records
.zip(prr.records().asScala)
.foldMap { case (orig, recordResult) =>
Option(recordResult.errorCode()) match {
// If the record had no error, treat as success
case None =>
TryBatchResult(Vector.empty, false, None)
// If it had a throughput exception, mark that and provide the original
case Some("ProvisionedThroughputExceededException") =>
TryBatchResult(Vector(orig), false, None)
// If any other error, mark success and throttled false for this record, and provide the original
case Some(_) =>
TryBatchResult(Vector(orig), true, Option(recordResult.errorMessage()))
}
}
else
TryBatchResult(Vector.empty, false, None)
}

/**
* Try writing a batch, and returns a list of the failures to be retried:
*
* If we are not throttled by kinesis, then the list is empty. If we are throttled by kinesis, the
* list contains throttled records and records that gave internal errors. If there is an
* exception, or if all records give internal errors, then we retry using the policy.
*/
private def tryWriteToKinesis[F[_]: Sync](
streamName: String,
kinesis: KinesisClient,
records: List[PutRecordsRequestEntry]
): F[Vector[PutRecordsRequestEntry]] =
Logger[F].debug(s"Writing ${records.size} records to ${streamName}") *>
Sync[F]
.blocking(putRecords(kinesis, streamName, records))
.map(TryBatchResult.build(records, _))
.flatMap { result =>
// If we encountered non-throttle errors, raise an exception. Otherwise, return all the requests that should
// be manually retried due to throttling
if (result.hadNonThrottleErrors)
Sync[F].raiseError(new RuntimeException(failureMessageForInternalErrors(records, streamName, result)))
else
result.nextBatchAttempt.pure[F]
}

private def writeToKinesis[F[_]: Parallel: Async](
throttlingErrorsPolicy: BackoffPolicy,
requestLimits: RequestLimits,
kinesis: KinesisClient,
streamName: String,
records: List[Sinkable]
): F[Unit] = {
val policyForThrottling = Retries.fibonacci[F](throttlingErrorsPolicy)

// First, tryWriteToKinesis - the AWS SDK will handle retries. If there are still failures after that, it will:
// - return messages for retries if we only hit throttliing
// - raise an error if we still have non-throttle failures after the SDK has carried out retries
def runAndCaptureFailures(ref: Ref[F, List[PutRecordsRequestEntry]]): F[List[PutRecordsRequestEntry]] =
for {
records <- ref.get
failures <- group(records, requestLimits.recordLimit, requestLimits.bytesLimit, getRecordSize)
.parTraverse(g => tryWriteToKinesis(streamName, kinesis, g))
flattened = failures.flatten
_ <- ref.set(flattened)
} yield flattened
for {
ref <- Ref.of[F, List[PutRecordsRequestEntry]](toKinesisRecords(records))
failures <- runAndCaptureFailures(ref)
.retryingOnFailures(
policy = policyForThrottling,
wasSuccessful = entries => Sync[F].pure(entries.isEmpty),
onFailure = { case (result, retryDetails) =>
val msg = failureMessageForThrottling(result, streamName)
Logger[F].warn(s"$msg (${retryDetails.retriesSoFar} retries from cats-retry)")
}
)
_ <- if (failures.isEmpty) Sync[F].unit
else Sync[F].raiseError(new RuntimeException(failureMessageForThrottling(failures, streamName)))
} yield ()
}

private final case class RequestLimits(recordLimit: Int, bytesLimit: Int)

private object Retries {

def fibonacci[F[_]: Applicative](config: BackoffPolicy): RetryPolicy[F] =
capBackoffAndRetries(config, RetryPolicies.fibonacciBackoff[F](config.minBackoff))

private def capBackoffAndRetries[F[_]: Applicative](config: BackoffPolicy, policy: RetryPolicy[F]): RetryPolicy[F] = {
val capped = RetryPolicies.capDelay[F](config.maxBackoff, policy)
config.maxRetries.fold(capped)(max => capped.join(RetryPolicies.limitRetries(max)))
}
}

private def getRecordSize(record: PutRecordsRequestEntry) =
record.data.asByteArrayUnsafe().length + record.partitionKey().getBytes(UTF_8).length

private def failureMessageForInternalErrors(
records: List[PutRecordsRequestEntry],
streamName: String,
result: TryBatchResult
): String = {
val exampleMessage = result.exampleInternalError.getOrElse("none")
s"Writing ${records.size} records to $streamName errored with internal failures. Example error message [$exampleMessage]"
}

private def failureMessageForThrottling(
records: List[PutRecordsRequestEntry],
streamName: String
): String =
s"Exceeded Kinesis provisioned throughput: ${records.size} records failed writing to $streamName."
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright (c) 2023-present Snowplow Analytics Ltd. All rights reserved.
*
* This program is licensed to you under the Snowplow Community License Version 1.0,
* and you may not use this file except in compliance with the Snowplow Community License Version 1.0.
* You may obtain a copy of the Snowplow Community License Version 1.0 at https://docs.snowplow.io/community-license-1.0
*/
package com.snowplowanalytics.snowplow.sinks.kinesis

import io.circe._
import io.circe.generic.semiauto._
import io.circe.config.syntax._
import scala.concurrent.duration.FiniteDuration

import java.net.URI

case class BackoffPolicy(
minBackoff: FiniteDuration,
maxBackoff: FiniteDuration,
maxRetries: Option[Int]
)

object BackoffPolicy {

implicit def backoffPolicyDecoder: Decoder[BackoffPolicy] =
deriveDecoder[BackoffPolicy]
}

case class KinesisSinkConfig(
streamName: String,
throttledBackoffPolicy: BackoffPolicy,
recordLimit: Int,
byteLimit: Int,
customEndpoint: Option[URI]
)

object KinesisSinkConfig {
implicit def decoder: Decoder[KinesisSinkConfig] =
deriveDecoder[KinesisSinkConfig]
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import org.typelevel.log4cats.{Logger, SelfAwareStructuredLogger}
import org.typelevel.log4cats.slf4j.Slf4jLogger
import software.amazon.awssdk.awscore.defaultsmode.DefaultsMode
import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain
import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient
Expand Down Expand Up @@ -114,7 +115,7 @@ object KinesisSource {
private def kinesisStream[F[_]: Async](config: KinesisSourceConfig): Stream[F, LowLevelEvents[Map[String, KinesisMetadata[F]]]] = {
val resources =
for {
region <- Resource.eval(KinesisSourceConfig.getRuntimeRegion)
region <- Resource.eval(Sync[F].delay((new DefaultAwsRegionProviderChain).getRegion))
consumerSettings <- Resource.pure[F, KinesisConsumerSettings](
KinesisConsumerSettings(
config.streamName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,12 @@
*/
package com.snowplowanalytics.snowplow.sources.kinesis

import cats.effect.Sync

import eu.timepit.refined.types.all.PosInt

import io.circe._
import io.circe.generic.extras.semiauto.deriveConfiguredDecoder
import io.circe.generic.extras.Configuration

import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain

import java.net.URI
import java.time.Instant

Expand All @@ -36,9 +31,6 @@ object KinesisSourceConfig {

private implicit val posIntDecoder: Decoder[PosInt] = Decoder.decodeInt.emap(PosInt.from)

private[kinesis] def getRuntimeRegion[F[_]: Sync]: F[Region] =
Sync[F].blocking((new DefaultAwsRegionProviderChain).getRegion)

sealed trait InitialPosition

object InitialPosition {
Expand Down
2 changes: 1 addition & 1 deletion project/BuildSettings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ object BuildSettings {
ThisBuild / mimaFailOnNoPrevious := false,
mimaBinaryIssueFilters ++= Seq(),
Test / test := {
mimaReportBinaryIssues.value
val _ = mimaReportBinaryIssues.value
(Test / test).value
}
)
Expand Down
Loading

0 comments on commit 66c40d3

Please sign in to comment.