Skip to content

Conversation

@felixblanke
Copy link
Collaborator

Adds lazy initialization to wavelet packet objects and addresses #91.

Adds the option to lazy initialize a wavelet packet object by passing lazy_init=True to __init__ or transform.
In the case, we avoid doing a expansion of the full packet tree.
If the user tries to access access a key not yet contained in the dict, we calculate the missing coeffs on the fly.

This allows for siginificant speedups if we are not interested in the full tree but only a subset of the nodes

The example of issue #69 is presented below:

import torch, ptwt
max_lev = 4
shape = (512, 512)

test_signal = torch.randn(shape)
full_packet = ptwt.WaveletPacket2D(test_signal, "haar", maxlevel=max_lev)
partial_packet = ptwt.WaveletPacket2D(test_signal, "haar", maxlevel=max_lev, lazy_init=True)

# Full expansion of the wavelet packet tree
wp_keys = ptwt.WaveletPacket2D.get_natural_order(max_lev)

# Partial expansion
keys = ['aaaa', 'aaad', 'aaah', 'aaav', 'aad', 'aah', 'aava', 'aavd',
        'aavh', 'aavv', 'ad', 'ah', 'ava', 'avd', 'avh', 'avv', 'd', 'h',
        'vaa', 'vad', 'vah', 'vav', 'vd', 'vh', 'vv']

print("Partial expansion: keys contained?", all(key in partial_packet for key in keys))

print("Init...")
# lazy initialization
[partial_packet[key] for key in keys]

print("Partial expansion: keys contained?", all(key in partial_packet for key in keys))
print("Partial expansion: wp_keys contained?", all(key in partial_packet for key in wp_keys))

print()
diffs = [((partial_packet[key] - full_packet[key]) ** 2).sum() for key in keys]
print(f"Squared difference: {sum(diffs)=}")

print(f"# Partial keys: {len(partial_packet.keys())}")
print(f"# Full keys: {len(full_packet.keys())}")

which outputs

Partial expansion: keys contained? False
Init...
Partial expansion: keys contained? True
Partial expansion: wp_keys contained? False

Squared difference: sum(diffs)=tensor(0.)
# Partial keys: 33
# Full keys: 341
@felixblanke felixblanke added the enhancement New feature or request label Jun 25, 2024
@v0lta v0lta self-assigned this Jun 26, 2024
@v0lta
Copy link
Owner

v0lta commented Jun 26, 2024

I was merging, but then I thought that not computing the entire tree immediately should never be a problem. We eagerly computed the entire tree because we knew we needed it for the deepfake detection project. But generally speaking, we don't know which parts of the tree a user might want to expand. So I am thinking we should never compute the entire tree by default.

@v0lta
Copy link
Owner

v0lta commented Jun 26, 2024

If that's true, we don't need the lazy_init argument and can just change the behaviour under the hood. I am putting a PR together.

@felixblanke
Copy link
Collaborator Author

I merged main into this PR. The product of the tested parameters gets quite high, running the full test suite show 11227 test for the packets module alone. We might want to consider reducing this

@v0lta
Copy link
Owner

v0lta commented Jun 26, 2024

I think we should not have the lazy_init argument and make computation on request the default. That way this PR does not add extra tests.

@felixblanke
Copy link
Collaborator Author

We could change the default value of lazy_init to True. Then users could decide to opt in to eager initialization.

@v0lta
Copy link
Owner

v0lta commented Jun 26, 2024

But then we have to keep all of the old eager code. I don't see why we would do that. It just adds extra complexity.

@v0lta
Copy link
Owner

v0lta commented Jun 26, 2024

Wait, let's not duplicate work. @felixblanke, are you doing this already?

@v0lta
Copy link
Owner

v0lta commented Jun 26, 2024

okay I am doing it.

@felixblanke
Copy link
Collaborator Author

Ah, wait. I am on it :D

@v0lta
Copy link
Owner

v0lta commented Jun 26, 2024

Okay I am done and running the tests.

@v0lta
Copy link
Owner

v0lta commented Jun 26, 2024

It looks like we can just remove the old recursive code without too much of a hassle.

@v0lta
Copy link
Owner

v0lta commented Jun 26, 2024

Turns out it works in most cases:

FAILED tests/test_packets.py::test_partial_expansion_1d[zero-haar] - assert False
FAILED tests/test_packets.py::test_partial_expansion_1d[zero-db4] - assert False
FAILED tests/test_packets.py::test_partial_expansion_1d[reflect-haar] - assert False
FAILED tests/test_packets.py::test_partial_expansion_1d[reflect-db4] - assert False
FAILED tests/test_packets.py::test_partial_expansion_1d[constant-haar] - assert False
FAILED tests/test_packets.py::test_partial_expansion_1d[constant-db4] - assert False
FAILED tests/test_packets.py::test_partial_expansion_1d[boundary-haar] - assert False
FAILED tests/test_packets.py::test_partial_expansion_1d[boundary-db4] - assert False
FAILED tests/test_packets.py::test_partial_expansion_2d[zero-haar] - assert False
FAILED tests/test_packets.py::test_partial_expansion_2d[zero-db4] - assert False
FAILED tests/test_packets.py::test_partial_expansion_2d[reflect-haar] - assert False
FAILED tests/test_packets.py::test_partial_expansion_2d[reflect-db4] - assert False
FAILED tests/test_packets.py::test_partial_expansion_2d[constant-haar] - assert False
FAILED tests/test_packets.py::test_partial_expansion_2d[constant-db4] - assert False
FAILED tests/test_packets.py::test_partial_expansion_2d[boundary-haar] - assert False
FAILED tests/test_packets.py::test_partial_expansion_2d[boundary-db4] - assert False
FAILED tests/test_packets.py::test_inverse_boundary_packet_1d - KeyError: 'Key da not found'
FAILED tests/test_packets.py::test_inverse_boundary_packet_2d - KeyError: 'Key aa not found'

failed with the current code.

@v0lta
Copy link
Owner

v0lta commented Jun 26, 2024

I am leaving for today, but am happy to look at this again tomorrow.

@felixblanke
Copy link
Collaborator Author

I added the function initialize to the packet objects to initialize all coefficients as described by a set of keys. This feels less clumsy than using a list comprehension (and is possibly more memory efficient)

@v0lta
Copy link
Owner

v0lta commented Jun 27, 2024

I am convinced this works. Lets merge.

@v0lta v0lta merged commit fa7af3d into main Jun 27, 2024
@v0lta v0lta deleted the feature/packets-partial-refinement branch June 27, 2024 14:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

3 participants