From e9795bb174b10a5a371b3a483285bb6bebce70ed Mon Sep 17 00:00:00 2001 From: Dmytro Nosan Date: Fri, 25 Oct 2024 15:00:10 +0300 Subject: [PATCH] TestcontainersBeanRegistrationAotProcessor that replaces InstanceSupplier of Container by either direct field usage or a reflection equivalent. If the field is private, the reflection will be used; otherwise, direct access to the field will be used --- .../ImportTestcontainersTests.java | 108 ++++++++++++++++++ .../TestcontainerFieldBeanDefinition.java | 5 +- ...ontainersBeanRegistrationAotProcessor.java | 107 +++++++++++++++++ .../TestcontainersPropertySource.java | 11 ++ .../resources/META-INF/spring/aot.factories | 6 +- 5 files changed, 234 insertions(+), 3 deletions(-) create mode 100644 spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/context/TestcontainersBeanRegistrationAotProcessor.java diff --git a/spring-boot-project/spring-boot-testcontainers/src/dockerTest/java/org/springframework/boot/testcontainers/ImportTestcontainersTests.java b/spring-boot-project/spring-boot-testcontainers/src/dockerTest/java/org/springframework/boot/testcontainers/ImportTestcontainersTests.java index c3d0bd43703b..2fd17e133739 100644 --- a/spring-boot-project/spring-boot-testcontainers/src/dockerTest/java/org/springframework/boot/testcontainers/ImportTestcontainersTests.java +++ b/spring-boot-project/spring-boot-testcontainers/src/dockerTest/java/org/springframework/boot/testcontainers/ImportTestcontainersTests.java @@ -18,17 +18,26 @@ import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; +import java.util.function.BiConsumer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.testcontainers.containers.Container; import org.testcontainers.containers.PostgreSQLContainer; +import org.springframework.aot.test.generate.TestGenerationContext; import org.springframework.boot.testcontainers.beans.TestcontainerBeanDefinition; import org.springframework.boot.testcontainers.context.ImportTestcontainers; import org.springframework.boot.testsupport.container.DisabledIfDockerUnavailable; import org.springframework.boot.testsupport.container.TestImage; +import org.springframework.context.ApplicationContextInitializer; import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.aot.ApplicationContextAotGenerator; +import org.springframework.context.support.GenericApplicationContext; +import org.springframework.core.test.tools.CompileWithForkedClassLoader; +import org.springframework.core.test.tools.Compiled; +import org.springframework.core.test.tools.TestCompiler; +import org.springframework.javapoet.ClassName; import org.springframework.test.context.DynamicPropertyRegistry; import org.springframework.test.context.DynamicPropertySource; @@ -43,6 +52,8 @@ @DisabledIfDockerUnavailable class ImportTestcontainersTests { + private final TestGenerationContext generationContext = new TestGenerationContext(); + private AnnotationConfigApplicationContext applicationContext; @AfterEach @@ -122,6 +133,81 @@ void importWhenHasBadArgsDynamicPropertySourceMethod() { .withMessage("@DynamicPropertySource method 'containerProperties' must be static"); } + @Test + @CompileWithForkedClassLoader + void importTestcontainersImportWithoutValueAotContribution() { + this.applicationContext = new AnnotationConfigApplicationContext(); + this.applicationContext.register(ImportWithoutValue.class); + compile((freshContext, compiled) -> { + PostgreSQLContainer container = freshContext.getBean(PostgreSQLContainer.class); + assertThat(container).isSameAs(ImportWithoutValue.container); + }); + } + + @Test + @CompileWithForkedClassLoader + void importTestcontainersImportWithValueAotContribution() { + this.applicationContext = new AnnotationConfigApplicationContext(); + this.applicationContext.register(ImportWithValue.class); + compile((freshContext, compiled) -> { + PostgreSQLContainer container = freshContext.getBean(PostgreSQLContainer.class); + assertThat(container).isSameAs(ContainerDefinitions.container); + }); + } + + @Test + @CompileWithForkedClassLoader + void importTestcontainersWithDynamicPropertySourceAotContribution() { + this.applicationContext = new AnnotationConfigApplicationContext(); + this.applicationContext.register(ContainerDefinitionsWithDynamicPropertySource.class); + compile((freshContext, compiled) -> { + PostgreSQLContainer container = freshContext.getBean(PostgreSQLContainer.class); + assertThat(container).isSameAs(ContainerDefinitionsWithDynamicPropertySource.container); + }); + } + + @Test + @CompileWithForkedClassLoader + void importTestcontainersWithCustomPostgreSQLContainerAotContribution() { + this.applicationContext = new AnnotationConfigApplicationContext(); + this.applicationContext.register(CustomPostgreSQLContainerDefinitions.class); + compile((freshContext, compiled) -> { + CustomPostgreSQLContainer container = freshContext.getBean(CustomPostgreSQLContainer.class); + assertThat(container).isSameAs(CustomPostgreSQLContainerDefinitions.container); + }); + } + + @Test + @CompileWithForkedClassLoader + void importTestcontainersWithNotAccessibleContainerAotContribution() { + this.applicationContext = new AnnotationConfigApplicationContext(); + this.applicationContext.register(ImportNotAccessibleContainer.class); + compile((freshContext, compiled) -> { + PostgreSQLContainer container = freshContext.getBean(PostgreSQLContainer.class); + assertThat(container).isSameAs(ImportNotAccessibleContainer.container); + }); + } + + @SuppressWarnings("unchecked") + private void compile(BiConsumer result) { + ClassName className = processAheadOfTime(); + TestCompiler.forSystem().with(this.generationContext).compile((compiled) -> { + GenericApplicationContext freshApplicationContext = new GenericApplicationContext(); + ApplicationContextInitializer initializer = compiled + .getInstance(ApplicationContextInitializer.class, className.toString()); + initializer.initialize(freshApplicationContext); + freshApplicationContext.refresh(); + result.accept(freshApplicationContext, compiled); + }); + } + + private ClassName processAheadOfTime() { + ClassName className = new ApplicationContextAotGenerator().processAheadOfTime(this.applicationContext, + this.generationContext); + this.generationContext.writeGeneratedContent(); + return className; + } + @ImportTestcontainers static class ImportWithoutValue { @@ -196,4 +282,26 @@ void containerProperties() { } + @ImportTestcontainers + static class CustomPostgreSQLContainerDefinitions { + + static CustomPostgreSQLContainer container = new CustomPostgreSQLContainer(); + + } + + static class CustomPostgreSQLContainer extends PostgreSQLContainer { + + CustomPostgreSQLContainer() { + super("postgres:14"); + } + + } + + @ImportTestcontainers + static class ImportNotAccessibleContainer { + + private static final PostgreSQLContainer container = TestImage.container(PostgreSQLContainer.class); + + } + } diff --git a/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/context/TestcontainerFieldBeanDefinition.java b/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/context/TestcontainerFieldBeanDefinition.java index c5cf32d4b1aa..2f81ad8a7ced 100644 --- a/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/context/TestcontainerFieldBeanDefinition.java +++ b/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/context/TestcontainerFieldBeanDefinition.java @@ -1,5 +1,5 @@ /* - * Copyright 2012-2023 the original author or authors. + * Copyright 2012-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,9 +38,10 @@ class TestcontainerFieldBeanDefinition extends RootBeanDefinition implements Tes TestcontainerFieldBeanDefinition(Field field, Container container) { this.container = container; this.annotations = MergedAnnotations.from(field); - this.setBeanClass(container.getClass()); + setBeanClass(container.getClass()); setInstanceSupplier(() -> container); setRole(ROLE_INFRASTRUCTURE); + setAttribute(TestcontainerFieldBeanDefinition.class.getName(), field); } @Override diff --git a/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/context/TestcontainersBeanRegistrationAotProcessor.java b/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/context/TestcontainersBeanRegistrationAotProcessor.java new file mode 100644 index 000000000000..570aa07568e8 --- /dev/null +++ b/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/context/TestcontainersBeanRegistrationAotProcessor.java @@ -0,0 +1,107 @@ +/* + * Copyright 2012-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.boot.testcontainers.context; + +import java.lang.reflect.Field; + +import javax.lang.model.element.Modifier; + +import org.testcontainers.containers.Container; + +import org.springframework.aot.generate.AccessControl; +import org.springframework.aot.generate.GeneratedMethod; +import org.springframework.aot.generate.GenerationContext; +import org.springframework.beans.factory.aot.BeanRegistrationAotContribution; +import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor; +import org.springframework.beans.factory.aot.BeanRegistrationCode; +import org.springframework.beans.factory.aot.BeanRegistrationCodeFragments; +import org.springframework.beans.factory.aot.BeanRegistrationCodeFragmentsDecorator; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.CodeBlock; +import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; + +/** + * {@link BeanRegistrationAotProcessor} that replaces InstanceSupplier of + * {@link Container} by either direct field usage or a reflection equivalent. + *

+ * If the field is private, the reflection will be used; otherwise, direct access to the + * field will be used. + * + * @author Dmytro Nosan + */ +class TestcontainersBeanRegistrationAotProcessor implements BeanRegistrationAotProcessor { + + @Override + public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) { + RootBeanDefinition bd = registeredBean.getMergedBeanDefinition(); + String attributeName = TestcontainerFieldBeanDefinition.class.getName(); + Object field = bd.getAttribute(attributeName); + if (field != null) { + Assert.isInstanceOf(Field.class, field, + "BeanDefinition attribute '" + attributeName + "' value must be a type of '" + Field.class + "'"); + return BeanRegistrationAotContribution.withCustomCodeFragments( + (codeFragments) -> new AotContribution(codeFragments, registeredBean, ((Field) field))); + } + return null; + } + + static class AotContribution extends BeanRegistrationCodeFragmentsDecorator { + + private final RegisteredBean registeredBean; + + private final Field field; + + AotContribution(BeanRegistrationCodeFragments delegate, RegisteredBean registeredBean, Field field) { + super(delegate); + this.registeredBean = registeredBean; + this.field = field; + } + + @Override + public ClassName getTarget(RegisteredBean registeredBean) { + return ClassName.get(this.field.getDeclaringClass()); + } + + @Override + public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext, + BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) { + if (AccessControl.forMember(this.field).isAccessibleFrom(beanRegistrationCode.getClassName())) { + return CodeBlock.of("() -> $T.$L", this.field.getDeclaringClass(), this.field.getName()); + } + generationContext.getRuntimeHints().reflection().registerField(this.field); + GeneratedMethod generatedMethod = beanRegistrationCode.getMethods() + .add("getInstance", (method) -> method.addModifiers(Modifier.PRIVATE, Modifier.STATIC) + .addJavadoc("Get the bean instance for '$L'.", this.registeredBean.getBeanName()) + .returns(this.registeredBean.getBeanClass()) + .addStatement("$T field = $T.findField($T.class, $S)", Field.class, ReflectionUtils.class, + this.field.getDeclaringClass(), this.field.getName()) + .addStatement("$T.notNull(field, $S)", Assert.class, + "Field '" + this.field.getName() + "' is not found") + .addStatement("$T.makeAccessible(field)", ReflectionUtils.class) + .addStatement("$T container = $T.getField(field, null)", Object.class, ReflectionUtils.class) + .addStatement("$T.notNull(container, $S)", Assert.class, + "Container field '" + this.field.getName() + "' must not have a null value") + .addStatement("return ($T) container", this.registeredBean.getBeanClass())); + return generatedMethod.toMethodReference().toCodeBlock(); + } + + } + +} diff --git a/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/properties/TestcontainersPropertySource.java b/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/properties/TestcontainersPropertySource.java index f1ecfe878c80..d49ef3413bce 100644 --- a/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/properties/TestcontainersPropertySource.java +++ b/spring-boot-project/spring-boot-testcontainers/src/main/java/org/springframework/boot/testcontainers/properties/TestcontainersPropertySource.java @@ -26,9 +26,11 @@ import org.testcontainers.containers.Container; import org.springframework.beans.BeansException; +import org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.beans.factory.support.RegisteredBean; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisherAware; @@ -166,4 +168,13 @@ public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) } + static class TestcontainersEventPublisherBeanRegistrationExcludeFilter implements BeanRegistrationExcludeFilter { + + @Override + public boolean isExcludedFromAotProcessing(RegisteredBean registeredBean) { + return EventPublisherRegistrar.NAME.equals(registeredBean.getBeanName()); + } + + } + } diff --git a/spring-boot-project/spring-boot-testcontainers/src/main/resources/META-INF/spring/aot.factories b/spring-boot-project/spring-boot-testcontainers/src/main/resources/META-INF/spring/aot.factories index 5b3d49bd5020..61ff6cf6d122 100644 --- a/spring-boot-project/spring-boot-testcontainers/src/main/resources/META-INF/spring/aot.factories +++ b/spring-boot-project/spring-boot-testcontainers/src/main/resources/META-INF/spring/aot.factories @@ -1,5 +1,9 @@ org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter=\ -org.springframework.boot.testcontainers.service.connection.ConnectionDetailsRegistrar.ServiceConnectionBeanRegistrationExcludeFilter +org.springframework.boot.testcontainers.service.connection.ConnectionDetailsRegistrar.ServiceConnectionBeanRegistrationExcludeFilter,\ +org.springframework.boot.testcontainers.properties.TestcontainersPropertySource.TestcontainersEventPublisherBeanRegistrationExcludeFilter org.springframework.aot.hint.RuntimeHintsRegistrar=\ org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory.ContainerConnectionDetailsFactoriesRuntimeHints + +org.springframework.beans.factory.aot.BeanRegistrationAotProcessor=\ +org.springframework.boot.testcontainers.context.TestcontainersBeanRegistrationAotProcessor