Fix CopyDefaultValuesFromExpectLowering

Previously it didn't cover some cases:
- non-composable fun with composable parameters
- default value expression that uses other arguments
- default value expression calling another Composable
- expect fun with Type parameter

The new implementation is inspired by Kotlin's ExpectToActualDefaultValueCopier (k/native lowering) + we remove the default values in the relevant expect functions to ensure the kotlin's lowering won't try to copy them again.

Test: added more tests in RunComposableTests

Change-Id: I033230653f076ef6d2900f6177c6649ce20851a2
diff --git a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/RunComposableTests.kt b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/RunComposableTests.kt
index 3bfd3f1..4aad0a8 100644
--- a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/RunComposableTests.kt
+++ b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/RunComposableTests.kt
@@ -94,6 +94,264 @@
         }
     }
 
+    @Test
+    fun testExpectWithGetExpectedPropertyInDefaultValueExpression() {
+        runCompose(
+            testFunBody = """
+                ExpectComposable { value ->
+                    results["defaultValue"] = value
+                }
+                ExpectComposable({ expectedProperty + expectedProperty.reversed() }) { value ->
+                    results["anotherValue"] = value
+                }
+            """.trimIndent(),
+            files = mapOf(
+                "Expect.kt" to """
+                    import androidx.compose.runtime.*
+
+                    expect val expectedProperty: String
+
+                    @Composable
+                    expect fun ExpectComposable(
+                        value: () -> String = { expectedProperty },
+                        content: @Composable (v: String) -> Unit
+                    )
+                """.trimIndent(),
+                "Actual.kt" to """
+                    import androidx.compose.runtime.*
+
+                    actual val expectedProperty = "actualExpectedProperty"
+
+                    @Composable
+                    actual fun ExpectComposable(
+                        value: () -> String,
+                        content: @Composable (v: String) -> Unit
+                    ) {
+                        content(value())
+                    }
+                """.trimIndent()
+            )
+        ) { results ->
+            assertEquals("actualExpectedProperty", results["defaultValue"])
+            assertEquals(
+                "actualExpectedProperty" + "actualExpectedProperty".reversed(),
+                results["anotherValue"]
+            )
+        }
+    }
+
+    @Test
+    fun testExpectWithComposableExpressionInDefaultValue() {
+        runCompose(
+            testFunBody = """
+                ExpectComposable { value ->
+                    results["defaultValue"] = value
+                }
+                ExpectComposable("anotherValue") { value ->
+                    results["anotherValue"] = value
+                }
+            """.trimIndent(),
+            files = mapOf(
+                "Expect.kt" to """
+                    import androidx.compose.runtime.*
+
+                    @Composable
+                    fun defaultValueComposable(): String {
+                        return "defaultValueComposable"
+                    }
+
+                    @Composable
+                    expect fun ExpectComposable(
+                        value: String = defaultValueComposable(),
+                        content: @Composable (v: String) -> Unit
+                    )
+                """.trimIndent(),
+                "Actual.kt" to """
+                    import androidx.compose.runtime.*
+
+                    @Composable
+                    actual fun ExpectComposable(
+                        value: String,
+                        content: @Composable (v: String) -> Unit
+                    ) {
+                        content(value)
+                    }
+                """.trimIndent()
+            )
+        ) { results ->
+            assertEquals("defaultValueComposable", results["defaultValue"])
+            assertEquals("anotherValue", results["anotherValue"])
+        }
+    }
+
+    @Test
+    fun testExpectWithTypedParameter() {
+        runCompose(
+            testFunBody = """
+                ExpectComposable<String>("aeiouy") { value ->
+                    results["defaultValue"] = value
+                }
+                ExpectComposable<String>("aeiouy", { "anotherValue" }) { value ->
+                    results["anotherValue"] = value
+                }
+            """.trimIndent(),
+            files = mapOf(
+                "Expect.kt" to """
+                    import androidx.compose.runtime.*
+
+                    @Composable
+                    expect fun <T> ExpectComposable(
+                        value: T,
+                        composeValue: @Composable () -> T = { value },
+                        content: @Composable (T) -> Unit
+                    )
+                """.trimIndent(),
+                "Actual.kt" to """
+                    import androidx.compose.runtime.*
+
+                    @Composable
+                    actual fun <T> ExpectComposable(
+                        value: T,
+                        composeValue: @Composable () -> T,
+                        content: @Composable (T) -> Unit
+                    ) {
+                        content(composeValue())
+                    }
+                """.trimIndent()
+            )
+        ) { results ->
+            assertEquals("aeiouy", results["defaultValue"])
+            assertEquals("anotherValue", results["anotherValue"])
+        }
+    }
+
+    @Test
+    fun testExpectWithRememberInDefaultValueExpression() {
+        runCompose(
+            testFunBody = """
+                ExpectComposable { value ->
+                    results["defaultValue"] = value
+                }
+                ExpectComposable(remember { "anotherRememberedValue" }) { value ->
+                    results["anotherValue"] = value
+                }
+            """.trimIndent(),
+            files = mapOf(
+                "Expect.kt" to """
+                    import androidx.compose.runtime.*
+
+                    @Composable
+                    expect fun ExpectComposable(
+                        value: String = remember { "rememberedDefaultValue" },
+                        content: @Composable (v: String) -> Unit
+                    )
+                """.trimIndent(),
+                "Actual.kt" to """
+                    import androidx.compose.runtime.*
+
+                    @Composable
+                    actual fun ExpectComposable(
+                        value: String,
+                        content: @Composable (v: String) -> Unit
+                    ) {
+                        content(value)
+                    }
+                """.trimIndent()
+            )
+        ) { results ->
+            assertEquals("rememberedDefaultValue", results["defaultValue"])
+            assertEquals("anotherRememberedValue", results["anotherValue"])
+        }
+    }
+
+    @Test
+    fun testExpectWithDefaultValueUsingAnotherArgument() {
+        runCompose(
+            testFunBody = """
+                ExpectComposable("AbccbA") { value ->
+                    results["defaultValue"] = value
+                }
+                ExpectComposable("123", { s -> s + s.reversed() }) { value ->
+                    results["anotherValue"] = value
+                }
+            """.trimIndent(),
+            files = mapOf(
+                "Expect.kt" to """
+                    import androidx.compose.runtime.*
+
+                    @Composable
+                    expect fun ExpectComposable(
+                        value: String,
+                        composeText: (String) -> String = { value },
+                        content: @Composable (v: String) -> Unit
+                    )
+                """.trimIndent(),
+                "Actual.kt" to """
+                    import androidx.compose.runtime.*
+
+                    @Composable
+                    actual fun ExpectComposable(
+                        value: String,
+                        composeText: (String) -> String,
+                        content: @Composable (v: String) -> Unit
+                    ) {
+                        content(composeText(value))
+                    }
+                """.trimIndent()
+            )
+        ) { results ->
+            assertEquals("AbccbA", results["defaultValue"])
+            assertEquals("123321", results["anotherValue"])
+        }
+    }
+
+    @Test
+    fun testNonComposableFunWithComposableParam() {
+        runCompose(
+            testFunBody = """
+                savedContentLambda = null
+                ExpectFunWithComposableParam { value ->
+                    results["defaultValue"] = value
+                }
+                savedContentLambda!!.invoke()
+
+                savedContentLambda = null
+                ExpectFunWithComposableParam("3.14") { value ->
+                    results["anotherValue"] = value
+                }
+                savedContentLambda!!.invoke()
+            """.trimIndent(),
+            files = mapOf(
+                "Expect.kt" to """
+                    import androidx.compose.runtime.*
+
+                    var savedContentLambda: (@Composable () -> Unit)? = null
+
+                    expect fun ExpectFunWithComposableParam(
+                        value: String = "000",
+                        content: @Composable (v: String) -> Unit
+                    )
+                """.trimIndent(),
+                "Actual.kt" to """
+                    import androidx.compose.runtime.*
+
+                    @Composable
+                    actual fun ExpectFunWithComposableParam(
+                        value: String,
+                        content: @Composable (v: String) -> Unit
+                    ) {
+                        savedContentLambda = {
+                            content(value)
+                        }
+                    }
+                """.trimIndent()
+            )
+        ) { results ->
+            assertEquals("000", results["defaultValue"])
+            assertEquals("3.14", results["anotherValue"])
+        }
+    }
+
     // This method was partially borrowed/copy-pasted from RobolectricComposeTester
     // where some of the code was commented out. Those commented out parts are needed here.
     private fun runCompose(
diff --git a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeIrGenerationExtension.kt b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeIrGenerationExtension.kt
index bbb5435..4629463 100644
--- a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeIrGenerationExtension.kt
+++ b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeIrGenerationExtension.kt
@@ -108,7 +108,7 @@
             metrics
         ).lower(moduleFragment)
 
-        CopyDefaultValuesFromExpectLowering().lower(moduleFragment)
+        CopyDefaultValuesFromExpectLowering(pluginContext).lower(moduleFragment)
 
         val mangler = when {
             pluginContext.platform.isJs() -> JsManglerIr
diff --git a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/CopyDefaultValuesFromExpectLowering.kt b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/CopyDefaultValuesFromExpectLowering.kt
index 439cfbe..1ce5f87 100644
--- a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/CopyDefaultValuesFromExpectLowering.kt
+++ b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/CopyDefaultValuesFromExpectLowering.kt
@@ -16,15 +16,42 @@
 
 package androidx.compose.compiler.plugins.kotlin.lower
 
+import androidx.compose.compiler.plugins.kotlin.ComposeFqNames
 import androidx.compose.compiler.plugins.kotlin.hasComposableAnnotation
-import org.jetbrains.kotlin.descriptors.FunctionDescriptor
+import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
+import org.jetbrains.kotlin.descriptors.MemberDescriptor
 import org.jetbrains.kotlin.ir.IrStatement
 import org.jetbrains.kotlin.ir.ObsoleteDescriptorBasedAPI
+import org.jetbrains.kotlin.ir.declarations.IrClass
+import org.jetbrains.kotlin.ir.declarations.IrEnumEntry
 import org.jetbrains.kotlin.ir.declarations.IrFunction
 import org.jetbrains.kotlin.ir.declarations.IrModuleFragment
+import org.jetbrains.kotlin.ir.declarations.IrProperty
+import org.jetbrains.kotlin.ir.declarations.IrTypeParameter
+import org.jetbrains.kotlin.ir.declarations.IrValueParameter
+import org.jetbrains.kotlin.ir.expressions.IrExpressionBody
+import org.jetbrains.kotlin.ir.symbols.IrClassSymbol
+import org.jetbrains.kotlin.ir.symbols.IrClassifierSymbol
+import org.jetbrains.kotlin.ir.symbols.IrConstructorSymbol
+import org.jetbrains.kotlin.ir.symbols.IrEnumEntrySymbol
+import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol
+import org.jetbrains.kotlin.ir.symbols.IrPropertySymbol
+import org.jetbrains.kotlin.ir.symbols.IrSimpleFunctionSymbol
+import org.jetbrains.kotlin.ir.symbols.IrTypeParameterSymbol
+import org.jetbrains.kotlin.ir.symbols.IrValueParameterSymbol
+import org.jetbrains.kotlin.ir.symbols.IrValueSymbol
+import org.jetbrains.kotlin.ir.util.DeepCopyIrTreeWithSymbols
+import org.jetbrains.kotlin.ir.util.DeepCopySymbolRemapper
+import org.jetbrains.kotlin.ir.util.DeepCopyTypeRemapper
+import org.jetbrains.kotlin.ir.util.hasAnnotation
+import org.jetbrains.kotlin.ir.util.patchDeclarationParents
+import org.jetbrains.kotlin.ir.util.referenceFunction
 import org.jetbrains.kotlin.ir.visitors.IrElementTransformerVoid
+import org.jetbrains.kotlin.ir.visitors.acceptVoid
 import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid
-import org.jetbrains.kotlin.resolve.multiplatform.findCompatibleExpectsForActual
+import org.jetbrains.kotlin.resolve.descriptorUtil.module
+import org.jetbrains.kotlin.resolve.descriptorUtil.propertyIfAccessor
+import org.jetbrains.kotlin.resolve.multiplatform.findCompatibleActualsForExpected
 
 /**
  * [ComposableFunctionBodyTransformer] relies on presence of default values in
@@ -37,55 +64,207 @@
  * This lowering needs to run before [ComposableFunctionBodyTransformer] and
  * before [ComposerParamTransformer].
  *
- * Fixes https://github.com/JetBrains/compose-jb/issues/1407
+ * Fixes:
+ * https://github.com/JetBrains/compose-jb/issues/1407
+ * https://github.com/JetBrains/compose-multiplatform/issues/2816
+ * https://github.com/JetBrains/compose-multiplatform/issues/2806
+ *
+ * This implementation is borrowed from Kotlin's ExpectToActualDefaultValueCopier.
+ * Currently, it heavily relies on descriptors to find expect for actuals or vice versa:
+ * findCompatibleActualsForExpected.
+ * Unlike ExpectToActualDefaultValueCopier, this lowering performs its transformations
+ * only for functions marked with @Composable annotation or
+ * for functions with @Composable lambdas in parameters.
+ *
+ * TODO(karpovich): When adding support for FIR we'll need to use different API.
+ * Likely: fun FirBasedSymbol<*>.getSingleCompatibleExpectForActualOrNull(): FirBasedSymbol<*>?
  */
 @OptIn(ObsoleteDescriptorBasedAPI::class)
-class CopyDefaultValuesFromExpectLowering : ModuleLoweringPass {
+class CopyDefaultValuesFromExpectLowering(
+    pluginContext: IrPluginContext
+) : ModuleLoweringPass, IrElementTransformerVoid() {
+
+    private val symbolTable = pluginContext.symbolTable
+
+    private fun isApplicable(declaration: IrFunction): Boolean {
+        return declaration.hasComposableAnnotation() ||
+            declaration.valueParameters.any {
+                it.type.hasAnnotation(ComposeFqNames.Composable)
+            }
+    }
+
+    override fun visitFunction(declaration: IrFunction): IrStatement {
+        val original = super.visitFunction(declaration) as? IrFunction ?: return declaration
+
+        if (!original.isExpect || !isApplicable(original)) {
+            return original
+        }
+
+        val actualForExpected = original.findActualForExpected()
+
+        original.valueParameters.forEachIndexed { index, expectValueParameter ->
+            val actualValueParameter = actualForExpected.valueParameters[index]
+            val expectDefaultValue = expectValueParameter.defaultValue
+            if (expectDefaultValue != null) {
+                actualValueParameter.defaultValue = expectDefaultValue
+                    .remapExpectValueSymbols()
+                    .patchDeclarationParents(actualForExpected)
+
+                // Remove a default value in the expect fun in order to prevent
+                // Kotlin expect/actual-related lowerings trying to copy the default values again
+                expectValueParameter.defaultValue = null
+            }
+        }
+        return original
+    }
 
     override fun lower(module: IrModuleFragment) {
-        // it uses FunctionDescriptor since current API (findCompatibleExpectedForActual)
-        // can return only a descriptor
-        val expectComposables = mutableMapOf<FunctionDescriptor, IrFunction>()
+        module.transformChildrenVoid(this)
+    }
 
-        // first pass to find expect functions with default values
-        module.transformChildrenVoid(object : IrElementTransformerVoid() {
-            override fun visitFunction(declaration: IrFunction): IrStatement {
-                if (declaration.isExpect && declaration.hasComposableAnnotation()) {
-                    val hasDefaultValues = declaration.valueParameters.any {
-                        it.defaultValue != null
+    private inline fun <reified T : IrFunction> T.findActualForExpected(): T =
+        symbolTable.referenceFunction(descriptor.findActualForExpect()).owner as T
+
+    private fun IrProperty.findActualForExpected(): IrProperty =
+        symbolTable.referenceProperty(descriptor.findActualForExpect()).owner
+
+    private fun IrClass.findActualForExpected(): IrClass =
+        symbolTable.referenceClass(descriptor.findActualForExpect()).owner
+
+    private fun IrEnumEntry.findActualForExpected(): IrEnumEntry =
+        symbolTable.referenceEnumEntry(descriptor.findActualForExpect()).owner
+
+    private inline fun <reified T : MemberDescriptor> T.findActualForExpect(): T {
+        if (!this.isExpect) error(this)
+        return (findCompatibleActualsForExpected(module).singleOrNull() ?: error(this)) as T
+    }
+
+    private fun IrExpressionBody.remapExpectValueSymbols(): IrExpressionBody {
+        class SymbolRemapper : DeepCopySymbolRemapper() {
+            override fun getReferencedClass(symbol: IrClassSymbol) =
+                if (symbol.descriptor.isExpect)
+                    symbol.owner.findActualForExpected().symbol
+                else super.getReferencedClass(symbol)
+
+            override fun getReferencedClassOrNull(symbol: IrClassSymbol?) =
+                symbol?.let { getReferencedClass(it) }
+
+            override fun getReferencedClassifier(symbol: IrClassifierSymbol): IrClassifierSymbol =
+                when (symbol) {
+                    is IrClassSymbol -> getReferencedClass(symbol)
+                    is IrTypeParameterSymbol -> remapExpectTypeParameter(symbol).symbol
+                    else -> error("Unexpected symbol $symbol ${symbol.descriptor}")
+                }
+
+            override fun getReferencedConstructor(symbol: IrConstructorSymbol) =
+                if (symbol.descriptor.isExpect)
+                    symbol.owner.findActualForExpected().symbol
+                else super.getReferencedConstructor(symbol)
+
+            override fun getReferencedFunction(symbol: IrFunctionSymbol): IrFunctionSymbol =
+                when (symbol) {
+                    is IrSimpleFunctionSymbol -> getReferencedSimpleFunction(symbol)
+                    is IrConstructorSymbol -> getReferencedConstructor(symbol)
+                    else -> error("Unexpected symbol $symbol ${symbol.descriptor}")
+                }
+
+            override fun getReferencedSimpleFunction(symbol: IrSimpleFunctionSymbol) = when {
+                symbol.descriptor.isExpect -> symbol.owner.findActualForExpected().symbol
+
+                symbol.descriptor.propertyIfAccessor.isExpect -> {
+                    val property = symbol.owner.correspondingPropertySymbol!!.owner
+                    val actualPropertyDescriptor = property.descriptor.findActualForExpect()
+                    val accessorDescriptor = when (symbol.owner) {
+                        property.getter -> actualPropertyDescriptor.getter!!
+                        property.setter -> actualPropertyDescriptor.setter!!
+                        else -> error("Unexpected accessor of $symbol ${symbol.descriptor}")
                     }
-                    if (hasDefaultValues) {
-                        expectComposables[declaration.descriptor] = declaration
+                    symbolTable.referenceFunction(accessorDescriptor) as IrSimpleFunctionSymbol
+                }
+
+                else -> super.getReferencedSimpleFunction(symbol)
+            }
+
+            override fun getReferencedProperty(symbol: IrPropertySymbol) =
+                if (symbol.descriptor.isExpect)
+                    symbol.owner.findActualForExpected().symbol
+                else
+                    super.getReferencedProperty(symbol)
+
+            override fun getReferencedEnumEntry(symbol: IrEnumEntrySymbol): IrEnumEntrySymbol =
+                if (symbol.descriptor.isExpect)
+                    symbol.owner.findActualForExpected().symbol
+                else
+                    super.getReferencedEnumEntry(symbol)
+
+            override fun getReferencedValue(symbol: IrValueSymbol) =
+                remapExpectValue(symbol)?.symbol ?: super.getReferencedValue(symbol)
+        }
+
+        val symbolRemapper = SymbolRemapper()
+        acceptVoid(symbolRemapper)
+
+        return transform(
+            transformer = DeepCopyIrTreeWithSymbols(
+                symbolRemapper, DeepCopyTypeRemapper(symbolRemapper)
+            ),
+            data = null
+        )
+    }
+
+    private fun remapExpectTypeParameter(symbol: IrTypeParameterSymbol): IrTypeParameter {
+        val parameter = symbol.owner
+        val parent = parameter.parent
+
+        return when (parent) {
+            is IrClass ->
+                if (!parent.descriptor.isExpect)
+                    parameter
+                else parent.findActualForExpected().typeParameters[parameter.index]
+
+            is IrFunction ->
+                if (!parent.descriptor.isExpect)
+                    parameter
+                else parent.findActualForExpected().typeParameters[parameter.index]
+
+            else -> error(parent)
+        }
+    }
+
+    private fun remapExpectValue(symbol: IrValueSymbol): IrValueParameter? {
+        if (symbol !is IrValueParameterSymbol) {
+            return null
+        }
+
+        val parameter = symbol.owner
+        val parent = parameter.parent
+
+        return when (parent) {
+            is IrClass ->
+                if (!parent.descriptor.isExpect)
+                    null
+                else {
+                    assert(parameter == parent.thisReceiver)
+                    parent.findActualForExpected().thisReceiver!!
+                }
+
+            is IrFunction ->
+                if (!parent.descriptor.isExpect)
+                    null
+                else when (parameter) {
+                    parent.dispatchReceiverParameter ->
+                        parent.findActualForExpected().dispatchReceiverParameter!!
+
+                    parent.extensionReceiverParameter ->
+                        parent.findActualForExpected().extensionReceiverParameter!!
+
+                    else -> {
+                        assert(parent.valueParameters[parameter.index] == parameter)
+                        parent.findActualForExpected().valueParameters[parameter.index]
                     }
                 }
-                return super.visitFunction(declaration)
-            }
-        })
 
-        // second pass to set corresponding default values
-        module.transformChildrenVoid(object : IrElementTransformerVoid() {
-            override fun visitFunction(declaration: IrFunction): IrStatement {
-                if (declaration.descriptor.isActual && declaration.hasComposableAnnotation()) {
-                    val compatibleExpects = declaration.descriptor.findCompatibleExpectsForActual {
-                        module.descriptor == it
-                    }
-                    if (compatibleExpects.isNotEmpty()) {
-                        val expectFun = compatibleExpects.firstOrNull {
-                            it in expectComposables
-                        }?.let {
-                            expectComposables[it]
-                        }
-
-                        if (expectFun != null) {
-                            declaration.valueParameters.forEachIndexed { index, it ->
-                                it.defaultValue =
-                                    it.defaultValue ?: expectFun.valueParameters[index].defaultValue
-                            }
-                        }
-                    }
-                }
-                return super.visitFunction(declaration)
-            }
-        })
+            else -> error(parent)
+        }
     }
 }
\ No newline at end of file