Add Kinesis Sink integration tests
colmsnowplow committed Nov 6, 2023
1 parent a2db51a commit 0b55503
Showing 4 changed files with 138 additions and 15 deletions.
KinesisSinkSpec.scala (new file)
@@ -0,0 +1,66 @@
/*
 * Copyright (c) 2023-present Snowplow Analytics Ltd. All rights reserved.
 *
 * This program is licensed to you under the Snowplow Community License Version 1.0,
 * and you may not use this file except in compliance with the Snowplow Community License Version 1.0.
 * You may obtain a copy of the Snowplow Community License Version 1.0 at https://docs.snowplow.io/community-license-1.0
 */
package com.snowplowanalytics.snowplow.sinks.kinesis

import cats.effect.{IO, Resource}
import cats.effect.testing.specs2.CatsResource

import scala.concurrent.duration.{DurationInt, FiniteDuration}

import org.specs2.mutable.SpecificationLike

import org.testcontainers.containers.localstack.LocalStackContainer

import software.amazon.awssdk.services.kinesis.KinesisAsyncClient
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain

import com.snowplowanalytics.snowplow.it.kinesis._
import com.snowplowanalytics.snowplow.sinks.{Sink, Sinkable}

import Utils._

class KinesisSinkSpec extends CatsResource[IO, (String, LocalStackContainer, KinesisAsyncClient, Sink[IO])] with SpecificationLike {
  import KinesisSinkSpec._

  override val Timeout: FiniteDuration = 3.minutes

  /** Resources which are shared across tests */
  override val resource: Resource[IO, (String, LocalStackContainer, KinesisAsyncClient, Sink[IO])] =
    for {
      region <- Resource.eval(IO.blocking((new DefaultAwsRegionProviderChain).getRegion))
      localstack <- Localstack.resource(region, KINESIS_INITIALIZE_STREAMS, KinesisSinkSpec.getClass.getSimpleName)
      kinesisClient <- Resource.eval(getKinesisClient(localstack.getEndpoint, region))
      testSink <- KinesisSink.resource[IO](getKinesisSinkConfig(localstack.getEndpoint)(testStream1Name))
    } yield (region.toString, localstack, kinesisClient, testSink)

  override def is = s2"""
  KinesisSinkSpec should
    write to output stream $e1
  """

  def e1 = withResource { case (region, _, kinesisClient, testSink) =>
    val testPayload = "test-payload"
    val testInput = List(Sinkable(testPayload.getBytes(), Some("myPk"), Map(("", ""))))

    for {
      _ <- testSink.sink(testInput)
      _ <- IO.sleep(3.seconds)
      result = getDataFromKinesis(kinesisClient, region, testStream1Name)
    } yield List(
      result.events must haveSize(1),
      result.events must beEqualTo(List(testPayload))
    )
  }
}

object KinesisSinkSpec {
  val testStream1Name = "test-sink-stream-1"
  val KINESIS_INITIALIZE_STREAMS: String =
    List(s"$testStream1Name:1").mkString(",")
}
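
Note: LocalStack's KINESIS_INITIALIZE_STREAMS variable takes a comma-separated list of name:shard-count entries, which is presumably why even a single stream is built via mkString. A minimal sketch of what a multi-stream setup could look like (the second stream name is hypothetical, not part of this commit):

// Hypothetical multi-stream initialisation, using the same name:shard-count format LocalStack expects.
val streams: List[(String, Int)] = List(
  "test-sink-stream-1" -> 1,
  "test-sink-stream-2" -> 2 // assumed extra stream, for illustration only
)
val KINESIS_INITIALIZE_STREAMS: String =
  streams.map { case (name, shards) => s"$name:$shards" }.mkString(",")
// yields "test-sink-stream-1:1,test-sink-stream-2:2"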
KinesisSourceSpec.scala
@@ -21,6 +21,7 @@ import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain

import com.snowplowanalytics.snowplow.sources.EventProcessingConfig
import com.snowplowanalytics.snowplow.sources.EventProcessingConfig.NoWindowing
+import com.snowplowanalytics.snowplow.it.kinesis._

import java.time.Instant

@@ -37,16 +38,16 @@ class KinesisSourceSpec
  override val resource: Resource[IO, (LocalStackContainer, KinesisAsyncClient, String => KinesisSourceConfig)] =
    for {
      region <- Resource.eval(IO.blocking((new DefaultAwsRegionProviderChain).getRegion))
-      localstack <- Localstack.resource(region, KINESIS_INITIALIZE_STREAMS)
+      localstack <- Localstack.resource(region, KINESIS_INITIALIZE_STREAMS, KinesisSourceSpec.getClass.getSimpleName)
      kinesisClient <- Resource.eval(getKinesisClient(localstack.getEndpoint, region))
-    } yield (localstack, kinesisClient, getKinesisConfig(localstack.getEndpoint)(_))
+    } yield (localstack, kinesisClient, getKinesisSourceConfig(localstack.getEndpoint)(_))

  override def is = s2"""
  KinesisSourceSpec should
    read from input stream $e1
  """

-  def e1 = withResource { case (_, kinesisClient, getKinesisConfig) =>
+  def e1 = withResource { case (_, kinesisClient, getKinesisSourceConfig) =>
    val testPayload = "test-payload"

    for {
@@ -55,7 +56,7 @@
      _ <- putDataToKinesis(kinesisClient, testStream1Name, testPayload)
      t2 <- IO.realTimeInstant
      processingConfig = new EventProcessingConfig(NoWindowing)
-      kinesisConfig = getKinesisConfig(testStream1Name)
+      kinesisConfig = getKinesisSourceConfig(testStream1Name)
      sourceAndAck <- KinesisSource.build[IO](kinesisConfig)
      stream = sourceAndAck.stream(processingConfig, testProcessor(refProcessed))
      fiber <- stream.compile.drain.start
@@ -73,7 +74,7 @@
}

object KinesisSourceSpec {
-  val testStream1Name = "test-stream-1"
+  val testStream1Name = "test-source-stream-1"
  val KINESIS_INITIALIZE_STREAMS: String =
    List(s"$testStream1Name:1").mkString(",")
}
Localstack.scala
@@ -5,7 +5,7 @@
 * and you may not use this file except in compliance with the Snowplow Community License Version 1.0.
 * You may obtain a copy of the Snowplow Community License Version 1.0 at https://docs.snowplow.io/community-license-1.0
 */
-package com.snowplowanalytics.snowplow.sources.kinesis
+package com.snowplowanalytics.snowplow.it.kinesis

import cats.effect.{IO, Resource}
import org.slf4j.LoggerFactory
@@ -17,19 +17,24 @@ import software.amazon.awssdk.regions.Region

object Localstack {

-  def resource(region: Region, kinesisInitializeStreams: String): Resource[IO, LocalStackContainer] =
+  def resource(
+    region: Region,
+    kinesisInitializeStreams: String,
+    loggerName: String
+  ): Resource[IO, LocalStackContainer] =
    Resource.make {
      val localstack = new LocalStackContainer(DockerImageName.parse("localstack/localstack:2.2.0"))
      localstack.addEnv("AWS_DEFAULT_REGION", region.id)
      localstack.addEnv("KINESIS_INITIALIZE_STREAMS", kinesisInitializeStreams)
      localstack.addEnv("DEBUG", "1")
      localstack.addExposedPort(4566)
      localstack.setWaitStrategy(Wait.forLogMessage(".*Ready.*", 1))
-      IO(startLocalstack(localstack))
+      IO(startLocalstack(localstack, loggerName))
    }(ls => IO.blocking(ls.stop()))

-  private def startLocalstack(localstack: LocalStackContainer): LocalStackContainer = {
+  private def startLocalstack(localstack: LocalStackContainer, loggerName: String): LocalStackContainer = {
    localstack.start()
-    val logger = LoggerFactory.getLogger(KinesisSourceSpec.getClass.getSimpleName)
+    val logger = LoggerFactory.getLogger(loggerName)
    val logs = new Slf4jLogConsumer(logger)
    localstack.followOutput(logs)
    localstack
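
Note: Localstack.resource pairs container start-up with a guaranteed stop via Resource.make, so a spec only has to compose it with its other resources. A minimal hypothetical usage, with the region, stream, and logger values assumed for illustration:

// Hypothetical stand-alone usage of Localstack.resource; all argument values are illustrative.
val container: Resource[IO, LocalStackContainer] =
  Localstack.resource(Region.EU_CENTRAL_1, "my-stream:1", "MySpec")

val printEndpoint: IO[Unit] =
  container.use(ls => IO.println(s"LocalStack listening at ${ls.getEndpoint}"))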
Utils.scala
@@ -5,23 +5,33 @@
 * and you may not use this file except in compliance with the Snowplow Community License Version 1.0.
 * You may obtain a copy of the Snowplow Community License Version 1.0 at https://docs.snowplow.io/community-license-1.0
 */
-package com.snowplowanalytics.snowplow.sources.kinesis
+package com.snowplowanalytics.snowplow.it.kinesis

import cats.effect.{IO, Ref}

import scala.concurrent.duration.Duration
import scala.concurrent.duration.FiniteDuration
import scala.jdk.CollectionConverters._
import scala.jdk.FutureConverters._
import scala.concurrent.Await

import eu.timepit.refined.types.numeric.PosInt

import software.amazon.awssdk.core.SdkBytes
import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient
-import software.amazon.awssdk.services.kinesis.model.{PutRecordRequest, PutRecordResponse}
+import software.amazon.awssdk.services.kinesis.model.{GetRecordsRequest, PutRecordRequest, PutRecordResponse}

+import com.snowplowanalytics.snowplow.sources.{EventProcessor, TokenedEvents}
+import com.snowplowanalytics.snowplow.sources.kinesis.KinesisSourceConfig
+import com.snowplowanalytics.snowplow.sinks.kinesis.{BackoffPolicy, KinesisSinkConfig}
+import software.amazon.awssdk.services.kinesis.model.GetShardIteratorRequest

import java.net.URI
import java.nio.charset.StandardCharsets
import java.util.UUID
import java.time.Instant

-import com.snowplowanalytics.snowplow.sources.{EventProcessor, TokenedEvents}
+import java.util.concurrent.TimeUnit

object Utils {

@@ -42,7 +52,39 @@ object Utils {
    IO.blocking(client.putRecord(record).get())
  }

-  def getKinesisConfig(endpoint: URI)(streamName: String): KinesisSourceConfig = KinesisSourceConfig(
+  /**
+   * getDataFromKinesis reads up to 1000 records from the start of the stream (TRIM_HORIZON), stringifies the data it
+   * finds, and returns a ReceivedEvents. It can be called at the end of simple tests to return data from a Kinesis stream.
+   *
+   * If future tests involve more data, we might amend it to poll the stream and return everything it finds after a
+   * period without any new data.
+   */
+  def getDataFromKinesis(
+    client: KinesisAsyncClient,
+    region: String,
+    streamName: String
+  ): ReceivedEvents = {
+
+    val shIterRequest = GetShardIteratorRequest
+      .builder()
+      .streamName(streamName)
+      .shardIteratorType("TRIM_HORIZON")
+      .shardId("shardId-000000000000")
+      .build()
+
+    val shIter = Await.result(client.getShardIterator(shIterRequest).asScala, Duration("5 seconds")).shardIterator()
+
+    val request = GetRecordsRequest
+      .builder()
+      .streamARN("arn:aws:kinesis:%s:000000000000:stream/%s".format(region, streamName))
+      .shardIterator(shIter)
+      .build()
+
+    ReceivedEvents(client.getRecords(request).get().records().asScala.toList.map(record => new String(record.data.asByteArray())), None)
+  }

+  def getKinesisSourceConfig(endpoint: URI)(streamName: String): KinesisSourceConfig = KinesisSourceConfig(
    UUID.randomUUID().toString,
    streamName,
    KinesisSourceConfig.InitialPosition.TrimHorizon,
@@ -53,6 +95,15 @@
    Some(endpoint)
  )

+  def getKinesisSinkConfig(endpoint: URI)(streamName: String): KinesisSinkConfig = KinesisSinkConfig(
+    streamName,
+    BackoffPolicy(FiniteDuration(1, TimeUnit.SECONDS), FiniteDuration(1, TimeUnit.SECONDS), None),
+    BackoffPolicy(FiniteDuration(1, TimeUnit.SECONDS), FiniteDuration(1, TimeUnit.SECONDS), None),
+    1000,
+    1000000,
+    Some(endpoint)
+  )

  def testProcessor(ref: Ref[IO, List[ReceivedEvents]]): EventProcessor[IO] =
    _.evalMap { case TokenedEvents(events, token, tstamp) =>
      val parsed = events.map(byteBuffer => StandardCharsets.UTF_8.decode(byteBuffer).toString)
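
Note: a minimal sketch of the polling variant anticipated in the getDataFromKinesis comment above. pollDataFromKinesis is hypothetical, not part of this commit; it follows nextShardIterator until a GetRecords call still returns nothing after a quiet pause, and assumes the single-shard streams these tests initialise plus one extra model import (GetRecordsResponse):

// Hypothetical polling reader: drains the shard until one quiet period yields no new records.
// Assumes the existing Utils imports plus software.amazon.awssdk.services.kinesis.model.GetRecordsResponse.
def pollDataFromKinesis(
  client: KinesisAsyncClient,
  streamName: String,
  quiet: FiniteDuration = FiniteDuration(2, TimeUnit.SECONDS)
): IO[ReceivedEvents] = {
  val iterRequest = GetShardIteratorRequest
    .builder()
    .streamName(streamName)
    .shardIteratorType("TRIM_HORIZON")
    .shardId("shardId-000000000000") // single-shard stream, as initialised in these tests
    .build()

  def read(iterator: String): IO[GetRecordsResponse] =
    IO.fromCompletableFuture(IO(client.getRecords(GetRecordsRequest.builder().shardIterator(iterator).build())))

  def decode(resp: GetRecordsResponse): List[String] =
    resp.records().asScala.toList.map(r => new String(r.data.asByteArray()))

  def loop(iterator: String, acc: List[String]): IO[List[String]] =
    read(iterator).flatMap { resp =>
      val batch = decode(resp)
      if (batch.nonEmpty) loop(resp.nextShardIterator(), acc ++ batch)
      else
        // quiet period: pause once, re-check, and stop if still nothing new
        IO.sleep(quiet) *> read(resp.nextShardIterator()).flatMap { again =>
          val more = decode(again)
          if (more.isEmpty) IO.pure(acc) else loop(again.nextShardIterator(), acc ++ more)
        }
    }

  IO.fromCompletableFuture(IO(client.getShardIterator(iterRequest)))
    .flatMap(resp => loop(resp.shardIterator(), Nil))
    .map(events => ReceivedEvents(events, None))
}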
