Fix CopyDefaultValuesFromExpectLowering
Previously it didn't cover some cases:
- non-composable fun with composable parameters
- default value expression that uses other arguments
- default value expression calling another Composable
- expect fun with Type parameter
The new implementation is inspired by Kotlin's ExpectToActualDefaultValueCopier (k/native lowering) + we remove the default values in the relevant expect functions to ensure the kotlin's lowering won't try to copy them again.
Test: added more tests in RunComposableTests
Change-Id: I033230653f076ef6d2900f6177c6649ce20851a2
diff --git a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/RunComposableTests.kt b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/RunComposableTests.kt
index 3bfd3f1..4aad0a8 100644
--- a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/RunComposableTests.kt
+++ b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/RunComposableTests.kt
@@ -94,6 +94,264 @@
}
}
+ @Test
+ fun testExpectWithGetExpectedPropertyInDefaultValueExpression() {
+ runCompose(
+ testFunBody = """
+ ExpectComposable { value ->
+ results["defaultValue"] = value
+ }
+ ExpectComposable({ expectedProperty + expectedProperty.reversed() }) { value ->
+ results["anotherValue"] = value
+ }
+ """.trimIndent(),
+ files = mapOf(
+ "Expect.kt" to """
+ import androidx.compose.runtime.*
+
+ expect val expectedProperty: String
+
+ @Composable
+ expect fun ExpectComposable(
+ value: () -> String = { expectedProperty },
+ content: @Composable (v: String) -> Unit
+ )
+ """.trimIndent(),
+ "Actual.kt" to """
+ import androidx.compose.runtime.*
+
+ actual val expectedProperty = "actualExpectedProperty"
+
+ @Composable
+ actual fun ExpectComposable(
+ value: () -> String,
+ content: @Composable (v: String) -> Unit
+ ) {
+ content(value())
+ }
+ """.trimIndent()
+ )
+ ) { results ->
+ assertEquals("actualExpectedProperty", results["defaultValue"])
+ assertEquals(
+ "actualExpectedProperty" + "actualExpectedProperty".reversed(),
+ results["anotherValue"]
+ )
+ }
+ }
+
+ @Test
+ fun testExpectWithComposableExpressionInDefaultValue() {
+ runCompose(
+ testFunBody = """
+ ExpectComposable { value ->
+ results["defaultValue"] = value
+ }
+ ExpectComposable("anotherValue") { value ->
+ results["anotherValue"] = value
+ }
+ """.trimIndent(),
+ files = mapOf(
+ "Expect.kt" to """
+ import androidx.compose.runtime.*
+
+ @Composable
+ fun defaultValueComposable(): String {
+ return "defaultValueComposable"
+ }
+
+ @Composable
+ expect fun ExpectComposable(
+ value: String = defaultValueComposable(),
+ content: @Composable (v: String) -> Unit
+ )
+ """.trimIndent(),
+ "Actual.kt" to """
+ import androidx.compose.runtime.*
+
+ @Composable
+ actual fun ExpectComposable(
+ value: String,
+ content: @Composable (v: String) -> Unit
+ ) {
+ content(value)
+ }
+ """.trimIndent()
+ )
+ ) { results ->
+ assertEquals("defaultValueComposable", results["defaultValue"])
+ assertEquals("anotherValue", results["anotherValue"])
+ }
+ }
+
+ @Test
+ fun testExpectWithTypedParameter() {
+ runCompose(
+ testFunBody = """
+ ExpectComposable<String>("aeiouy") { value ->
+ results["defaultValue"] = value
+ }
+ ExpectComposable<String>("aeiouy", { "anotherValue" }) { value ->
+ results["anotherValue"] = value
+ }
+ """.trimIndent(),
+ files = mapOf(
+ "Expect.kt" to """
+ import androidx.compose.runtime.*
+
+ @Composable
+ expect fun <T> ExpectComposable(
+ value: T,
+ composeValue: @Composable () -> T = { value },
+ content: @Composable (T) -> Unit
+ )
+ """.trimIndent(),
+ "Actual.kt" to """
+ import androidx.compose.runtime.*
+
+ @Composable
+ actual fun <T> ExpectComposable(
+ value: T,
+ composeValue: @Composable () -> T,
+ content: @Composable (T) -> Unit
+ ) {
+ content(composeValue())
+ }
+ """.trimIndent()
+ )
+ ) { results ->
+ assertEquals("aeiouy", results["defaultValue"])
+ assertEquals("anotherValue", results["anotherValue"])
+ }
+ }
+
+ @Test
+ fun testExpectWithRememberInDefaultValueExpression() {
+ runCompose(
+ testFunBody = """
+ ExpectComposable { value ->
+ results["defaultValue"] = value
+ }
+ ExpectComposable(remember { "anotherRememberedValue" }) { value ->
+ results["anotherValue"] = value
+ }
+ """.trimIndent(),
+ files = mapOf(
+ "Expect.kt" to """
+ import androidx.compose.runtime.*
+
+ @Composable
+ expect fun ExpectComposable(
+ value: String = remember { "rememberedDefaultValue" },
+ content: @Composable (v: String) -> Unit
+ )
+ """.trimIndent(),
+ "Actual.kt" to """
+ import androidx.compose.runtime.*
+
+ @Composable
+ actual fun ExpectComposable(
+ value: String,
+ content: @Composable (v: String) -> Unit
+ ) {
+ content(value)
+ }
+ """.trimIndent()
+ )
+ ) { results ->
+ assertEquals("rememberedDefaultValue", results["defaultValue"])
+ assertEquals("anotherRememberedValue", results["anotherValue"])
+ }
+ }
+
+ @Test
+ fun testExpectWithDefaultValueUsingAnotherArgument() {
+ runCompose(
+ testFunBody = """
+ ExpectComposable("AbccbA") { value ->
+ results["defaultValue"] = value
+ }
+ ExpectComposable("123", { s -> s + s.reversed() }) { value ->
+ results["anotherValue"] = value
+ }
+ """.trimIndent(),
+ files = mapOf(
+ "Expect.kt" to """
+ import androidx.compose.runtime.*
+
+ @Composable
+ expect fun ExpectComposable(
+ value: String,
+ composeText: (String) -> String = { value },
+ content: @Composable (v: String) -> Unit
+ )
+ """.trimIndent(),
+ "Actual.kt" to """
+ import androidx.compose.runtime.*
+
+ @Composable
+ actual fun ExpectComposable(
+ value: String,
+ composeText: (String) -> String,
+ content: @Composable (v: String) -> Unit
+ ) {
+ content(composeText(value))
+ }
+ """.trimIndent()
+ )
+ ) { results ->
+ assertEquals("AbccbA", results["defaultValue"])
+ assertEquals("123321", results["anotherValue"])
+ }
+ }
+
+ @Test
+ fun testNonComposableFunWithComposableParam() {
+ runCompose(
+ testFunBody = """
+ savedContentLambda = null
+ ExpectFunWithComposableParam { value ->
+ results["defaultValue"] = value
+ }
+ savedContentLambda!!.invoke()
+
+ savedContentLambda = null
+ ExpectFunWithComposableParam("3.14") { value ->
+ results["anotherValue"] = value
+ }
+ savedContentLambda!!.invoke()
+ """.trimIndent(),
+ files = mapOf(
+ "Expect.kt" to """
+ import androidx.compose.runtime.*
+
+ var savedContentLambda: (@Composable () -> Unit)? = null
+
+ expect fun ExpectFunWithComposableParam(
+ value: String = "000",
+ content: @Composable (v: String) -> Unit
+ )
+ """.trimIndent(),
+ "Actual.kt" to """
+ import androidx.compose.runtime.*
+
+ @Composable
+ actual fun ExpectFunWithComposableParam(
+ value: String,
+ content: @Composable (v: String) -> Unit
+ ) {
+ savedContentLambda = {
+ content(value)
+ }
+ }
+ """.trimIndent()
+ )
+ ) { results ->
+ assertEquals("000", results["defaultValue"])
+ assertEquals("3.14", results["anotherValue"])
+ }
+ }
+
// This method was partially borrowed/copy-pasted from RobolectricComposeTester
// where some of the code was commented out. Those commented out parts are needed here.
private fun runCompose(
diff --git a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeIrGenerationExtension.kt b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeIrGenerationExtension.kt
index bbb5435..4629463 100644
--- a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeIrGenerationExtension.kt
+++ b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeIrGenerationExtension.kt
@@ -108,7 +108,7 @@
metrics
).lower(moduleFragment)
- CopyDefaultValuesFromExpectLowering().lower(moduleFragment)
+ CopyDefaultValuesFromExpectLowering(pluginContext).lower(moduleFragment)
val mangler = when {
pluginContext.platform.isJs() -> JsManglerIr
diff --git a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/CopyDefaultValuesFromExpectLowering.kt b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/CopyDefaultValuesFromExpectLowering.kt
index 439cfbe..1ce5f87 100644
--- a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/CopyDefaultValuesFromExpectLowering.kt
+++ b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/CopyDefaultValuesFromExpectLowering.kt
@@ -16,15 +16,42 @@
package androidx.compose.compiler.plugins.kotlin.lower
+import androidx.compose.compiler.plugins.kotlin.ComposeFqNames
import androidx.compose.compiler.plugins.kotlin.hasComposableAnnotation
-import org.jetbrains.kotlin.descriptors.FunctionDescriptor
+import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
+import org.jetbrains.kotlin.descriptors.MemberDescriptor
import org.jetbrains.kotlin.ir.IrStatement
import org.jetbrains.kotlin.ir.ObsoleteDescriptorBasedAPI
+import org.jetbrains.kotlin.ir.declarations.IrClass
+import org.jetbrains.kotlin.ir.declarations.IrEnumEntry
import org.jetbrains.kotlin.ir.declarations.IrFunction
import org.jetbrains.kotlin.ir.declarations.IrModuleFragment
+import org.jetbrains.kotlin.ir.declarations.IrProperty
+import org.jetbrains.kotlin.ir.declarations.IrTypeParameter
+import org.jetbrains.kotlin.ir.declarations.IrValueParameter
+import org.jetbrains.kotlin.ir.expressions.IrExpressionBody
+import org.jetbrains.kotlin.ir.symbols.IrClassSymbol
+import org.jetbrains.kotlin.ir.symbols.IrClassifierSymbol
+import org.jetbrains.kotlin.ir.symbols.IrConstructorSymbol
+import org.jetbrains.kotlin.ir.symbols.IrEnumEntrySymbol
+import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol
+import org.jetbrains.kotlin.ir.symbols.IrPropertySymbol
+import org.jetbrains.kotlin.ir.symbols.IrSimpleFunctionSymbol
+import org.jetbrains.kotlin.ir.symbols.IrTypeParameterSymbol
+import org.jetbrains.kotlin.ir.symbols.IrValueParameterSymbol
+import org.jetbrains.kotlin.ir.symbols.IrValueSymbol
+import org.jetbrains.kotlin.ir.util.DeepCopyIrTreeWithSymbols
+import org.jetbrains.kotlin.ir.util.DeepCopySymbolRemapper
+import org.jetbrains.kotlin.ir.util.DeepCopyTypeRemapper
+import org.jetbrains.kotlin.ir.util.hasAnnotation
+import org.jetbrains.kotlin.ir.util.patchDeclarationParents
+import org.jetbrains.kotlin.ir.util.referenceFunction
import org.jetbrains.kotlin.ir.visitors.IrElementTransformerVoid
+import org.jetbrains.kotlin.ir.visitors.acceptVoid
import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid
-import org.jetbrains.kotlin.resolve.multiplatform.findCompatibleExpectsForActual
+import org.jetbrains.kotlin.resolve.descriptorUtil.module
+import org.jetbrains.kotlin.resolve.descriptorUtil.propertyIfAccessor
+import org.jetbrains.kotlin.resolve.multiplatform.findCompatibleActualsForExpected
/**
* [ComposableFunctionBodyTransformer] relies on presence of default values in
@@ -37,55 +64,207 @@
* This lowering needs to run before [ComposableFunctionBodyTransformer] and
* before [ComposerParamTransformer].
*
- * Fixes https://github.com/JetBrains/compose-jb/issues/1407
+ * Fixes:
+ * https://github.com/JetBrains/compose-jb/issues/1407
+ * https://github.com/JetBrains/compose-multiplatform/issues/2816
+ * https://github.com/JetBrains/compose-multiplatform/issues/2806
+ *
+ * This implementation is borrowed from Kotlin's ExpectToActualDefaultValueCopier.
+ * Currently, it heavily relies on descriptors to find expect for actuals or vice versa:
+ * findCompatibleActualsForExpected.
+ * Unlike ExpectToActualDefaultValueCopier, this lowering performs its transformations
+ * only for functions marked with @Composable annotation or
+ * for functions with @Composable lambdas in parameters.
+ *
+ * TODO(karpovich): When adding support for FIR we'll need to use different API.
+ * Likely: fun FirBasedSymbol<*>.getSingleCompatibleExpectForActualOrNull(): FirBasedSymbol<*>?
*/
@OptIn(ObsoleteDescriptorBasedAPI::class)
-class CopyDefaultValuesFromExpectLowering : ModuleLoweringPass {
+class CopyDefaultValuesFromExpectLowering(
+ pluginContext: IrPluginContext
+) : ModuleLoweringPass, IrElementTransformerVoid() {
+
+ private val symbolTable = pluginContext.symbolTable
+
+ private fun isApplicable(declaration: IrFunction): Boolean {
+ return declaration.hasComposableAnnotation() ||
+ declaration.valueParameters.any {
+ it.type.hasAnnotation(ComposeFqNames.Composable)
+ }
+ }
+
+ override fun visitFunction(declaration: IrFunction): IrStatement {
+ val original = super.visitFunction(declaration) as? IrFunction ?: return declaration
+
+ if (!original.isExpect || !isApplicable(original)) {
+ return original
+ }
+
+ val actualForExpected = original.findActualForExpected()
+
+ original.valueParameters.forEachIndexed { index, expectValueParameter ->
+ val actualValueParameter = actualForExpected.valueParameters[index]
+ val expectDefaultValue = expectValueParameter.defaultValue
+ if (expectDefaultValue != null) {
+ actualValueParameter.defaultValue = expectDefaultValue
+ .remapExpectValueSymbols()
+ .patchDeclarationParents(actualForExpected)
+
+ // Remove a default value in the expect fun in order to prevent
+ // Kotlin expect/actual-related lowerings trying to copy the default values again
+ expectValueParameter.defaultValue = null
+ }
+ }
+ return original
+ }
override fun lower(module: IrModuleFragment) {
- // it uses FunctionDescriptor since current API (findCompatibleExpectedForActual)
- // can return only a descriptor
- val expectComposables = mutableMapOf<FunctionDescriptor, IrFunction>()
+ module.transformChildrenVoid(this)
+ }
- // first pass to find expect functions with default values
- module.transformChildrenVoid(object : IrElementTransformerVoid() {
- override fun visitFunction(declaration: IrFunction): IrStatement {
- if (declaration.isExpect && declaration.hasComposableAnnotation()) {
- val hasDefaultValues = declaration.valueParameters.any {
- it.defaultValue != null
+ private inline fun <reified T : IrFunction> T.findActualForExpected(): T =
+ symbolTable.referenceFunction(descriptor.findActualForExpect()).owner as T
+
+ private fun IrProperty.findActualForExpected(): IrProperty =
+ symbolTable.referenceProperty(descriptor.findActualForExpect()).owner
+
+ private fun IrClass.findActualForExpected(): IrClass =
+ symbolTable.referenceClass(descriptor.findActualForExpect()).owner
+
+ private fun IrEnumEntry.findActualForExpected(): IrEnumEntry =
+ symbolTable.referenceEnumEntry(descriptor.findActualForExpect()).owner
+
+ private inline fun <reified T : MemberDescriptor> T.findActualForExpect(): T {
+ if (!this.isExpect) error(this)
+ return (findCompatibleActualsForExpected(module).singleOrNull() ?: error(this)) as T
+ }
+
+ private fun IrExpressionBody.remapExpectValueSymbols(): IrExpressionBody {
+ class SymbolRemapper : DeepCopySymbolRemapper() {
+ override fun getReferencedClass(symbol: IrClassSymbol) =
+ if (symbol.descriptor.isExpect)
+ symbol.owner.findActualForExpected().symbol
+ else super.getReferencedClass(symbol)
+
+ override fun getReferencedClassOrNull(symbol: IrClassSymbol?) =
+ symbol?.let { getReferencedClass(it) }
+
+ override fun getReferencedClassifier(symbol: IrClassifierSymbol): IrClassifierSymbol =
+ when (symbol) {
+ is IrClassSymbol -> getReferencedClass(symbol)
+ is IrTypeParameterSymbol -> remapExpectTypeParameter(symbol).symbol
+ else -> error("Unexpected symbol $symbol ${symbol.descriptor}")
+ }
+
+ override fun getReferencedConstructor(symbol: IrConstructorSymbol) =
+ if (symbol.descriptor.isExpect)
+ symbol.owner.findActualForExpected().symbol
+ else super.getReferencedConstructor(symbol)
+
+ override fun getReferencedFunction(symbol: IrFunctionSymbol): IrFunctionSymbol =
+ when (symbol) {
+ is IrSimpleFunctionSymbol -> getReferencedSimpleFunction(symbol)
+ is IrConstructorSymbol -> getReferencedConstructor(symbol)
+ else -> error("Unexpected symbol $symbol ${symbol.descriptor}")
+ }
+
+ override fun getReferencedSimpleFunction(symbol: IrSimpleFunctionSymbol) = when {
+ symbol.descriptor.isExpect -> symbol.owner.findActualForExpected().symbol
+
+ symbol.descriptor.propertyIfAccessor.isExpect -> {
+ val property = symbol.owner.correspondingPropertySymbol!!.owner
+ val actualPropertyDescriptor = property.descriptor.findActualForExpect()
+ val accessorDescriptor = when (symbol.owner) {
+ property.getter -> actualPropertyDescriptor.getter!!
+ property.setter -> actualPropertyDescriptor.setter!!
+ else -> error("Unexpected accessor of $symbol ${symbol.descriptor}")
}
- if (hasDefaultValues) {
- expectComposables[declaration.descriptor] = declaration
+ symbolTable.referenceFunction(accessorDescriptor) as IrSimpleFunctionSymbol
+ }
+
+ else -> super.getReferencedSimpleFunction(symbol)
+ }
+
+ override fun getReferencedProperty(symbol: IrPropertySymbol) =
+ if (symbol.descriptor.isExpect)
+ symbol.owner.findActualForExpected().symbol
+ else
+ super.getReferencedProperty(symbol)
+
+ override fun getReferencedEnumEntry(symbol: IrEnumEntrySymbol): IrEnumEntrySymbol =
+ if (symbol.descriptor.isExpect)
+ symbol.owner.findActualForExpected().symbol
+ else
+ super.getReferencedEnumEntry(symbol)
+
+ override fun getReferencedValue(symbol: IrValueSymbol) =
+ remapExpectValue(symbol)?.symbol ?: super.getReferencedValue(symbol)
+ }
+
+ val symbolRemapper = SymbolRemapper()
+ acceptVoid(symbolRemapper)
+
+ return transform(
+ transformer = DeepCopyIrTreeWithSymbols(
+ symbolRemapper, DeepCopyTypeRemapper(symbolRemapper)
+ ),
+ data = null
+ )
+ }
+
+ private fun remapExpectTypeParameter(symbol: IrTypeParameterSymbol): IrTypeParameter {
+ val parameter = symbol.owner
+ val parent = parameter.parent
+
+ return when (parent) {
+ is IrClass ->
+ if (!parent.descriptor.isExpect)
+ parameter
+ else parent.findActualForExpected().typeParameters[parameter.index]
+
+ is IrFunction ->
+ if (!parent.descriptor.isExpect)
+ parameter
+ else parent.findActualForExpected().typeParameters[parameter.index]
+
+ else -> error(parent)
+ }
+ }
+
+ private fun remapExpectValue(symbol: IrValueSymbol): IrValueParameter? {
+ if (symbol !is IrValueParameterSymbol) {
+ return null
+ }
+
+ val parameter = symbol.owner
+ val parent = parameter.parent
+
+ return when (parent) {
+ is IrClass ->
+ if (!parent.descriptor.isExpect)
+ null
+ else {
+ assert(parameter == parent.thisReceiver)
+ parent.findActualForExpected().thisReceiver!!
+ }
+
+ is IrFunction ->
+ if (!parent.descriptor.isExpect)
+ null
+ else when (parameter) {
+ parent.dispatchReceiverParameter ->
+ parent.findActualForExpected().dispatchReceiverParameter!!
+
+ parent.extensionReceiverParameter ->
+ parent.findActualForExpected().extensionReceiverParameter!!
+
+ else -> {
+ assert(parent.valueParameters[parameter.index] == parameter)
+ parent.findActualForExpected().valueParameters[parameter.index]
}
}
- return super.visitFunction(declaration)
- }
- })
- // second pass to set corresponding default values
- module.transformChildrenVoid(object : IrElementTransformerVoid() {
- override fun visitFunction(declaration: IrFunction): IrStatement {
- if (declaration.descriptor.isActual && declaration.hasComposableAnnotation()) {
- val compatibleExpects = declaration.descriptor.findCompatibleExpectsForActual {
- module.descriptor == it
- }
- if (compatibleExpects.isNotEmpty()) {
- val expectFun = compatibleExpects.firstOrNull {
- it in expectComposables
- }?.let {
- expectComposables[it]
- }
-
- if (expectFun != null) {
- declaration.valueParameters.forEachIndexed { index, it ->
- it.defaultValue =
- it.defaultValue ?: expectFun.valueParameters[index].defaultValue
- }
- }
- }
- }
- return super.visitFunction(declaration)
- }
- })
+ else -> error(parent)
+ }
}
}
\ No newline at end of file