Skip to content

Implement type imports and exports #7330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
Reuse the type-sorting code from ModuleUtils in wasm-merge
  • Loading branch information
vouillon committed Feb 27, 2025
commit e3168071200e35abb2c8bc025c977aea2c0825e8
13 changes: 10 additions & 3 deletions src/ir/module-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -737,9 +737,9 @@ std::vector<HeapType> getPrivateHeapTypes(Module& wasm) {
return types;
}

IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
auto counts = collectHeapTypeInfo(wasm, TypeInclusion::BinaryTypes);

IndexedHeapTypes sortHeapTypes(Module& wasm,
InsertOrderedMap<HeapType, HeapTypeInfo>& counts,
std::function<HeapType(HeapType)> map) {
// Collect the rec groups.
std::unordered_map<RecGroup, size_t> groupIndices;
std::vector<RecGroup> groups;
Expand All @@ -766,6 +766,7 @@ IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
for (size_t i = 0; i < groups.size(); ++i) {
for (auto type : groups[i]) {
for (auto child : type.getReferencedHeapTypes()) {
child = map(child);
if (child.isBasic()) {
continue;
}
Expand Down Expand Up @@ -862,4 +863,10 @@ IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
return indexedTypes;
}

IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
auto counts = collectHeapTypeInfo(wasm, TypeInclusion::BinaryTypes);
return sortHeapTypes(
wasm, counts, [](HeapType type) -> HeapType { return type; });
}

} // namespace wasm::ModuleUtils
6 changes: 6 additions & 0 deletions src/ir/module-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,12 @@ struct IndexedHeapTypes {
std::unordered_map<HeapType, Index> indices;
};

// Orders the types to be valid (after renaming by the map function)
// and sorts the types by frequency of use to minimize code size.
IndexedHeapTypes sortHeapTypes(Module& wasm,
InsertOrderedMap<HeapType, HeapTypeInfo>& counts,
std::function<HeapType(HeapType)> map);

// Similar to `collectHeapTypes`, but provides fast lookup of the index for each
// type as well. Also orders the types to be valid and sorts the types by
// frequency of use to minimize code size.
Expand Down
75 changes: 13 additions & 62 deletions src/tools/wasm-merge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,54 +440,6 @@ void checkLimit(bool& valid, const char* kind, T* export_, T* import) {
}
}

// Sort heap types to put children (rewritten by map) before heap type
std::vector<HeapType> sortHeapTypes(std::vector<HeapType>& types,
std::function<HeapType(HeapType)> map) {
// Collect the rec groups.
std::unordered_map<RecGroup, size_t> groupIndices;
std::vector<RecGroup> groups;
for (auto& type : types) {
auto group = type.getRecGroup();
if (groupIndices.insert({group, groups.size()}).second) {
groups.push_back(group);
}
}

// Collect the reverse dependencies of each group.
std::vector<std::unordered_set<size_t>> depSets(groups.size());
for (size_t i = 0; i < groups.size(); ++i) {
for (auto type : groups[i]) {
for (auto child : type.getReferencedHeapTypes()) {
child = map(child);
if (child.isBasic()) {
continue;
}
auto childGroup = child.getRecGroup();
if (childGroup == groups[i]) {
continue;
}
depSets[groupIndices.at(childGroup)].insert(i);
}
}
}
TopologicalSort::Graph deps;
deps.reserve(groups.size());
for (size_t i = 0; i < groups.size(); ++i) {
deps.emplace_back(depSets[i].begin(), depSets[i].end());
}

auto sorted = TopologicalSort::sort(deps);

std::vector<HeapType> sortedTypes;
sortedTypes.reserve(types.size());
for (auto groupIndex : sorted) {
for (auto type : groups[groupIndex]) {
sortedTypes.push_back(type);
}
}
return sortedTypes;
}

// Find pairs of matching type imports and type exports, and make uses
// of the import refer to the exported item (which has been merged
// into the module).
Expand All @@ -507,14 +459,15 @@ void fuseTypeImportsAndTypeExports() {
moduleTypeExportMap[exportInfo.moduleName][exportInfo.baseName] =
ex->heaptype;
}

auto heapTypeInfo = ModuleUtils::collectHeapTypeInfo(merged);

// For each type import, see whether it has a corresponding
// export, check that the imported type is a subtype of the import
// bound. Record the corresponding mapping.
bool valid = true;
std::unordered_map<HeapType, HeapType> typeUpdates;
std::vector<HeapType> heapTypes = ModuleUtils::collectHeapTypes(merged);

for (HeapType& heapType : heapTypes) {
for (auto& [heapType, _] : heapTypeInfo) {
if (heapType.isImport()) {
Import import = heapType.getImport();
if (auto newType = moduleTypeExportMap[import.module].find(import.base);
Expand Down Expand Up @@ -574,26 +527,24 @@ void fuseTypeImportsAndTypeExports() {
};

// Sort heap types so that children are before
heapTypes = sortHeapTypes(heapTypes, initialMap);
std::unordered_map<HeapType, size_t> typeIndices;
TypeBuilder typeBuilder(heapTypes.size());
for (size_t i = 0; i < heapTypes.size(); i++) {
typeIndices[heapTypes[i]] = i;
}
ModuleUtils::IndexedHeapTypes indexedTypes =
ModuleUtils::sortHeapTypes(merged, heapTypeInfo, initialMap);

TypeBuilder typeBuilder(indexedTypes.types.size());

// Map from a heap type to the corresponding temporary type
auto map = [&](HeapType type) -> HeapType {
type = initialMap(type);
if (type.isBasic()) {
return type;
}
return typeBuilder[typeIndices[type]];
return typeBuilder[indexedTypes.indices[type]];
};

// Build new types
std::optional<RecGroup> lastGroup = std::nullopt;
for (size_t i = 0; i < heapTypes.size(); i++) {
HeapType heapType = heapTypes[i];
for (size_t i = 0; i < indexedTypes.types.size(); i++) {
HeapType heapType = indexedTypes.types[i];
typeBuilder[i].copy(heapType, map);
auto currGroup = heapType.getRecGroup();
if (lastGroup != currGroup && currGroup.size() > 1) {
Expand All @@ -606,11 +557,11 @@ void fuseTypeImportsAndTypeExports() {

// Map old types to the new ones.
GlobalTypeRewriter::TypeMap oldToNewTypes;
for (HeapType heapType : heapTypes) {
for (HeapType heapType : indexedTypes.types) {
HeapType type = heapType;
type = initialMap(type);
if (!type.isBasic()) {
type = newTypes[typeIndices[type]];
type = newTypes[indexedTypes.indices[type]];
}
oldToNewTypes[heapType] = type;
}
Expand Down