|
16 | 16 | from typing_extensions import ParamSpec, TypeVar |
17 | 17 |
|
18 | 18 | from .constants import ( |
| 19 | + BoundaryMode, |
19 | 20 | OrthogonalizeMethod, |
20 | 21 | WaveletCoeff2d, |
21 | 22 | WaveletCoeffNd, |
@@ -103,6 +104,31 @@ def _get_len(wavelet: Union[tuple[torch.Tensor, ...], str, Wavelet]) -> int: |
103 | 104 | return len(_as_wavelet(wavelet)) |
104 | 105 |
|
105 | 106 |
|
| 107 | +def _translate_boundary_strings(pywt_mode: BoundaryMode) -> str: |
| 108 | + """Translate pywt mode strings to PyTorch mode strings. |
| 109 | +
|
| 110 | + We support constant, zero, reflect, and periodic. |
| 111 | + Unfortunately, "constant" has different meanings in the |
| 112 | + Pytorch and PyWavelet communities. |
| 113 | +
|
| 114 | + Raises: |
| 115 | + ValueError: If the padding mode is not supported. |
| 116 | + """ |
| 117 | + if pywt_mode == "constant": |
| 118 | + return "replicate" |
| 119 | + elif pywt_mode == "zero": |
| 120 | + return "constant" |
| 121 | + elif pywt_mode == "reflect": |
| 122 | + return pywt_mode |
| 123 | + elif pywt_mode == "periodic": |
| 124 | + return "circular" |
| 125 | + elif pywt_mode == "symmetric": |
| 126 | + # pytorch does not support symmetric mode, |
| 127 | + # we have our own implementation. |
| 128 | + return pywt_mode |
| 129 | + raise ValueError(f"Padding mode not supported: {pywt_mode}") |
| 130 | + |
| 131 | + |
106 | 132 | def _is_orthogonalize_method_supported( |
107 | 133 | orthogonalization: Optional[OrthogonalizeMethod], |
108 | 134 | ) -> bool: |
@@ -166,6 +192,94 @@ def _construct_nd_filt( |
166 | 192 | return filter_tensor |
167 | 193 |
|
168 | 194 |
|
| 195 | +def _fwt_padn( |
| 196 | + data: torch.Tensor, |
| 197 | + wavelet: Union[Wavelet, str], |
| 198 | + ndim: int, |
| 199 | + *, |
| 200 | + mode: BoundaryMode, |
| 201 | + padding: Optional[tuple[int, ...]] = None, |
| 202 | +) -> torch.Tensor: |
| 203 | + """Pad data for the Nd-FWT. |
| 204 | +
|
| 205 | + This function pads the last :math:`N` axes. |
| 206 | +
|
| 207 | + Args: |
| 208 | + data (torch.Tensor): Input data with :math:`N+1` dimensions. |
| 209 | + wavelet (Wavelet or str): A pywt wavelet compatible object or |
| 210 | + the name of a pywt wavelet. |
| 211 | + Refer to the output from ``pywt.wavelist(kind='discrete')`` |
| 212 | + for possible choices. |
| 213 | + ndim (int): The number of dimentsions :math:`N`. |
| 214 | + mode: The desired padding mode for extending the signal along the edges. |
| 215 | + See :data:`ptwt.constants.BoundaryMode`. |
| 216 | + padding (tuple[int, ...], optional): A tuple with the number of |
| 217 | + padded values on the respective side of each transformed axis |
| 218 | + of `data`. Expects to have :math:`2N` entries. |
| 219 | + If None, the padding values are computed based |
| 220 | + on the signal shape and the wavelet length. |
| 221 | + Defaults to None. |
| 222 | +
|
| 223 | + Returns: |
| 224 | + The padded output tensor. |
| 225 | +
|
| 226 | + Raises: |
| 227 | + ValueError: If `padding` is not None and has a length different |
| 228 | + from :math:`2N`. |
| 229 | + """ |
| 230 | + if padding is None: |
| 231 | + padding_lst: list[int] = [] |
| 232 | + for dim in range(1, ndim + 1): |
| 233 | + pad_axis_r, pad_axis_l = _get_pad(data.shape[-dim], _get_len(wavelet)) |
| 234 | + padding_lst.extend([pad_axis_l, pad_axis_r]) |
| 235 | + padding = tuple(padding_lst) |
| 236 | + |
| 237 | + if len(padding) != 2 * ndim: |
| 238 | + raise ValueError("Invalid number of padding values passed!") |
| 239 | + |
| 240 | + if mode == "symmetric": |
| 241 | + padding_pairs = list(zip(padding[::2], padding[1::2])) |
| 242 | + data_pad = _pad_symmetric(data, padding_pairs[::-1]) |
| 243 | + else: |
| 244 | + data_pad = torch.nn.functional.pad( |
| 245 | + data, padding, mode=_translate_boundary_strings(mode) |
| 246 | + ) |
| 247 | + return data_pad |
| 248 | + |
| 249 | + |
| 250 | +def _get_pad(data_len: int, filt_len: int) -> tuple[int, int]: |
| 251 | + """Compute the required padding. |
| 252 | +
|
| 253 | + Args: |
| 254 | + data_len (int): The length of the input vector. |
| 255 | + filt_len (int): The size of the used filter. |
| 256 | +
|
| 257 | + Returns: |
| 258 | + A tuple (padr, padl). The first entry specifies how many numbers |
| 259 | + to attach on the right. The second entry covers the left side. |
| 260 | + """ |
| 261 | + # pad to ensure we see all filter positions and |
| 262 | + # for pywt compatability. |
| 263 | + # convolution output length: |
| 264 | + # see https://arxiv.org/pdf/1603.07285.pdf section 2.3: |
| 265 | + # floor([data_len - filt_len]/2) + 1 |
| 266 | + # should equal pywt output length |
| 267 | + # floor((data_len + filt_len - 1)/2) |
| 268 | + # => floor([data_len + total_pad - filt_len]/2) + 1 |
| 269 | + # = floor((data_len + filt_len - 1)/2) |
| 270 | + # (data_len + total_pad - filt_len) + 2 = data_len + filt_len - 1 |
| 271 | + # total_pad = 2*filt_len - 3 |
| 272 | + |
| 273 | + # we pad half of the total requried padding on each side. |
| 274 | + padr = (2 * filt_len - 3) // 2 |
| 275 | + padl = (2 * filt_len - 3) // 2 |
| 276 | + |
| 277 | + # pad to even singal length. |
| 278 | + padr += data_len % 2 |
| 279 | + |
| 280 | + return padr, padl |
| 281 | + |
| 282 | + |
169 | 283 | def _pad_symmetric_1d(signal: torch.Tensor, pad_list: tuple[int, int]) -> torch.Tensor: |
170 | 284 | padl, padr = pad_list |
171 | 285 | dimlen = signal.shape[0] |
|
0 commit comments