Skip to content

Commit 47493ee

Browse files
committed
Add general nd padding function
1 parent 58da7e6 commit 47493ee

File tree

4 files changed

+122
-70
lines changed

4 files changed

+122
-70
lines changed

‎src/ptwt/_util.py‎

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing_extensions import ParamSpec, TypeVar
1717

1818
from .constants import (
19+
BoundaryMode,
1920
OrthogonalizeMethod,
2021
WaveletCoeff2d,
2122
WaveletCoeffNd,
@@ -103,6 +104,31 @@ def _get_len(wavelet: Union[tuple[torch.Tensor, ...], str, Wavelet]) -> int:
103104
return len(_as_wavelet(wavelet))
104105

105106

107+
def _translate_boundary_strings(pywt_mode: BoundaryMode) -> str:
108+
"""Translate pywt mode strings to PyTorch mode strings.
109+
110+
We support constant, zero, reflect, and periodic.
111+
Unfortunately, "constant" has different meanings in the
112+
Pytorch and PyWavelet communities.
113+
114+
Raises:
115+
ValueError: If the padding mode is not supported.
116+
"""
117+
if pywt_mode == "constant":
118+
return "replicate"
119+
elif pywt_mode == "zero":
120+
return "constant"
121+
elif pywt_mode == "reflect":
122+
return pywt_mode
123+
elif pywt_mode == "periodic":
124+
return "circular"
125+
elif pywt_mode == "symmetric":
126+
# pytorch does not support symmetric mode,
127+
# we have our own implementation.
128+
return pywt_mode
129+
raise ValueError(f"Padding mode not supported: {pywt_mode}")
130+
131+
106132
def _is_orthogonalize_method_supported(
107133
orthogonalization: Optional[OrthogonalizeMethod],
108134
) -> bool:
@@ -166,6 +192,94 @@ def _construct_nd_filt(
166192
return filter_tensor
167193

168194

195+
def _fwt_padn(
196+
data: torch.Tensor,
197+
wavelet: Union[Wavelet, str],
198+
ndim: int,
199+
*,
200+
mode: BoundaryMode,
201+
padding: Optional[tuple[int, ...]] = None,
202+
) -> torch.Tensor:
203+
"""Pad data for the Nd-FWT.
204+
205+
This function pads the last :math:`N` axes.
206+
207+
Args:
208+
data (torch.Tensor): Input data with :math:`N+1` dimensions.
209+
wavelet (Wavelet or str): A pywt wavelet compatible object or
210+
the name of a pywt wavelet.
211+
Refer to the output from ``pywt.wavelist(kind='discrete')``
212+
for possible choices.
213+
ndim (int): The number of dimentsions :math:`N`.
214+
mode: The desired padding mode for extending the signal along the edges.
215+
See :data:`ptwt.constants.BoundaryMode`.
216+
padding (tuple[int, ...], optional): A tuple with the number of
217+
padded values on the respective side of each transformed axis
218+
of `data`. Expects to have :math:`2N` entries.
219+
If None, the padding values are computed based
220+
on the signal shape and the wavelet length.
221+
Defaults to None.
222+
223+
Returns:
224+
The padded output tensor.
225+
226+
Raises:
227+
ValueError: If `padding` is not None and has a length different
228+
from :math:`2N`.
229+
"""
230+
if padding is None:
231+
padding_lst: list[int] = []
232+
for dim in range(1, ndim + 1):
233+
pad_axis_r, pad_axis_l = _get_pad(data.shape[-dim], _get_len(wavelet))
234+
padding_lst.extend([pad_axis_l, pad_axis_r])
235+
padding = tuple(padding_lst)
236+
237+
if len(padding) != 2 * ndim:
238+
raise ValueError("Invalid number of padding values passed!")
239+
240+
if mode == "symmetric":
241+
padding_pairs = list(zip(padding[::2], padding[1::2]))
242+
data_pad = _pad_symmetric(data, padding_pairs[::-1])
243+
else:
244+
data_pad = torch.nn.functional.pad(
245+
data, padding, mode=_translate_boundary_strings(mode)
246+
)
247+
return data_pad
248+
249+
250+
def _get_pad(data_len: int, filt_len: int) -> tuple[int, int]:
251+
"""Compute the required padding.
252+
253+
Args:
254+
data_len (int): The length of the input vector.
255+
filt_len (int): The size of the used filter.
256+
257+
Returns:
258+
A tuple (padr, padl). The first entry specifies how many numbers
259+
to attach on the right. The second entry covers the left side.
260+
"""
261+
# pad to ensure we see all filter positions and
262+
# for pywt compatability.
263+
# convolution output length:
264+
# see https://arxiv.org/pdf/1603.07285.pdf section 2.3:
265+
# floor([data_len - filt_len]/2) + 1
266+
# should equal pywt output length
267+
# floor((data_len + filt_len - 1)/2)
268+
# => floor([data_len + total_pad - filt_len]/2) + 1
269+
# = floor((data_len + filt_len - 1)/2)
270+
# (data_len + total_pad - filt_len) + 2 = data_len + filt_len - 1
271+
# total_pad = 2*filt_len - 3
272+
273+
# we pad half of the total requried padding on each side.
274+
padr = (2 * filt_len - 3) // 2
275+
padl = (2 * filt_len - 3) // 2
276+
277+
# pad to even singal length.
278+
padr += data_len % 2
279+
280+
return padr, padl
281+
282+
169283
def _pad_symmetric_1d(signal: torch.Tensor, pad_list: tuple[int, int]) -> torch.Tensor:
170284
padl, padr = pad_list
171285
dimlen = signal.shape[0]

‎src/ptwt/conv_transform.py‎

Lines changed: 2 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
_as_wavelet,
1717
_check_same_device_dtype,
1818
_get_len,
19+
_get_pad,
1920
_pad_symmetric,
2021
_postprocess_coeffs,
2122
_postprocess_tensor,
2223
_preprocess_coeffs,
2324
_preprocess_tensor,
25+
_translate_boundary_strings,
2426
)
2527
from .constants import BoundaryMode, WaveletCoeff2d
2628

@@ -74,64 +76,6 @@ def _get_filter_tensors(
7476
return dec_lo_tensor, dec_hi_tensor, rec_lo_tensor, rec_hi_tensor
7577

7678

77-
def _get_pad(data_len: int, filt_len: int) -> tuple[int, int]:
78-
"""Compute the required padding.
79-
80-
Args:
81-
data_len (int): The length of the input vector.
82-
filt_len (int): The size of the used filter.
83-
84-
Returns:
85-
A tuple (padr, padl). The first entry specifies how many numbers
86-
to attach on the right. The second entry covers the left side.
87-
"""
88-
# pad to ensure we see all filter positions and
89-
# for pywt compatability.
90-
# convolution output length:
91-
# see https://arxiv.org/pdf/1603.07285.pdf section 2.3:
92-
# floor([data_len - filt_len]/2) + 1
93-
# should equal pywt output length
94-
# floor((data_len + filt_len - 1)/2)
95-
# => floor([data_len + total_pad - filt_len]/2) + 1
96-
# = floor((data_len + filt_len - 1)/2)
97-
# (data_len + total_pad - filt_len) + 2 = data_len + filt_len - 1
98-
# total_pad = 2*filt_len - 3
99-
100-
# we pad half of the total requried padding on each side.
101-
padr = (2 * filt_len - 3) // 2
102-
padl = (2 * filt_len - 3) // 2
103-
104-
# pad to even singal length.
105-
padr += data_len % 2
106-
107-
return padr, padl
108-
109-
110-
def _translate_boundary_strings(pywt_mode: BoundaryMode) -> str:
111-
"""Translate pywt mode strings to PyTorch mode strings.
112-
113-
We support constant, zero, reflect, and periodic.
114-
Unfortunately, "constant" has different meanings in the
115-
Pytorch and PyWavelet communities.
116-
117-
Raises:
118-
ValueError: If the padding mode is not supported.
119-
"""
120-
if pywt_mode == "constant":
121-
return "replicate"
122-
elif pywt_mode == "zero":
123-
return "constant"
124-
elif pywt_mode == "reflect":
125-
return pywt_mode
126-
elif pywt_mode == "periodic":
127-
return "circular"
128-
elif pywt_mode == "symmetric":
129-
# pytorch does not support symmetric mode,
130-
# we have our own implementation.
131-
return pywt_mode
132-
raise ValueError(f"Padding mode not supported: {pywt_mode}")
133-
134-
13579
def _fwt_pad(
13680
data: torch.Tensor,
13781
wavelet: Union[Wavelet, str],

‎src/ptwt/conv_transform_2.py‎

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,17 @@
1515
Wavelet,
1616
_check_same_device_dtype,
1717
_get_len,
18+
_get_pad,
1819
_outer,
1920
_pad_symmetric,
2021
_postprocess_coeffs,
2122
_postprocess_tensor,
2223
_preprocess_coeffs,
2324
_preprocess_tensor,
24-
)
25-
from .constants import BoundaryMode, WaveletCoeff2d, WaveletDetailTuple2d
26-
from .conv_transform import (
27-
_adjust_padding_at_reconstruction,
28-
_get_filter_tensors,
29-
_get_pad,
3025
_translate_boundary_strings,
3126
)
27+
from .constants import BoundaryMode, WaveletCoeff2d, WaveletDetailTuple2d
28+
from .conv_transform import _adjust_padding_at_reconstruction, _get_filter_tensors
3229

3330

3431
def _construct_2d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:

‎src/ptwt/conv_transform_3.py‎

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,17 @@
1515
_as_wavelet,
1616
_check_same_device_dtype,
1717
_get_len,
18+
_get_pad,
1819
_outer,
1920
_pad_symmetric,
2021
_postprocess_coeffs,
2122
_postprocess_tensor,
2223
_preprocess_coeffs,
2324
_preprocess_tensor,
24-
)
25-
from .constants import BoundaryMode, WaveletCoeffNd, WaveletDetailDict
26-
from .conv_transform import (
27-
_adjust_padding_at_reconstruction,
28-
_get_filter_tensors,
29-
_get_pad,
3025
_translate_boundary_strings,
3126
)
27+
from .constants import BoundaryMode, WaveletCoeffNd, WaveletDetailDict
28+
from .conv_transform import _adjust_padding_at_reconstruction, _get_filter_tensors
3229

3330

3431
def _construct_3d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)