66
77from __future__ import annotations
88
9- from functools import partial
109from typing import Optional , Union
1110
1211import pywt
1312import torch
1413
1514from ._util import (
1615 Wavelet ,
17- _as_wavelet ,
18- _check_axes_argument ,
19- _check_if_tensor ,
20- _fold_axes ,
16+ _check_same_device_dtype ,
2117 _get_len ,
22- _is_dtype_supported ,
23- _map_result ,
2418 _outer ,
2519 _pad_symmetric ,
26- _swap_axes ,
27- _undo_swap_axes ,
28- _unfold_axes ,
20+ _postprocess_coeffs ,
21+ _postprocess_tensor ,
22+ _preprocess_coeffs ,
23+ _preprocess_tensor ,
2924)
3025from .constants import BoundaryMode , WaveletCoeff2d , WaveletDetailTuple2d
3126from .conv_transform import (
@@ -107,32 +102,6 @@ def _fwt_pad2(
107102 return data_pad
108103
109104
110- def _waverec2d_fold_channels_2d_list (
111- coeffs : WaveletCoeff2d ,
112- ) -> tuple [WaveletCoeff2d , list [int ]]:
113- # fold the input coefficients for processing conv2d_transpose.
114- ds = list (_check_if_tensor (coeffs [0 ]).shape )
115- return _map_result (coeffs , lambda t : _fold_axes (t , 2 )[0 ]), ds
116-
117-
118- def _preprocess_tensor_dec2d (
119- data : torch .Tensor ,
120- ) -> tuple [torch .Tensor , Union [list [int ], None ]]:
121- # Preprocess multidimensional input.
122- ds = None
123- if len (data .shape ) == 2 :
124- data = data .unsqueeze (0 ).unsqueeze (0 )
125- elif len (data .shape ) == 3 :
126- # add a channel dimension for torch.
127- data = data .unsqueeze (1 )
128- elif len (data .shape ) >= 4 :
129- data , ds = _fold_axes (data , 2 )
130- data = data .unsqueeze (1 )
131- elif len (data .shape ) == 1 :
132- raise ValueError ("More than one input dimension required." )
133- return data , ds
134-
135-
136105def wavedec2 (
137106 data : torch .Tensor ,
138107 wavelet : Union [Wavelet , str ],
@@ -183,11 +152,6 @@ def wavedec2(
183152 A tuple containing the wavelet coefficients in pywt order,
184153 see :data:`ptwt.constants.WaveletCoeff2d`.
185154
186- Raises:
187- ValueError: If the dimensionality or the dtype of the input data tensor
188- is unsupported or if the provided ``axes``
189- input has a length other than two.
190-
191155 Example:
192156 >>> import torch
193157 >>> import ptwt, pywt
@@ -200,17 +164,7 @@ def wavedec2(
200164 >>> level=2, mode="zero")
201165
202166 """
203- if not _is_dtype_supported (data .dtype ):
204- raise ValueError (f"Input dtype { data .dtype } not supported" )
205-
206- if tuple (axes ) != (- 2 , - 1 ):
207- if len (axes ) != 2 :
208- raise ValueError ("2D transforms work with two axes." )
209- else :
210- data = _swap_axes (data , list (axes ))
211-
212- wavelet = _as_wavelet (wavelet )
213- data , ds = _preprocess_tensor_dec2d (data )
167+ data , ds = _preprocess_tensor (data , ndim = 2 , axes = axes )
214168 dec_lo , dec_hi , _ , _ = _get_filter_tensors (
215169 wavelet , flip = True , device = data .device , dtype = data .dtype
216170 )
@@ -234,13 +188,7 @@ def wavedec2(
234188 res_ll = res_ll .squeeze (1 )
235189 result : WaveletCoeff2d = res_ll , * result_lst
236190
237- if ds :
238- _unfold_axes2 = partial (_unfold_axes , ds = ds , keep_no = 2 )
239- result = _map_result (result , _unfold_axes2 )
240-
241- if axes != (- 2 , - 1 ):
242- undo_swap_fn = partial (_undo_swap_axes , axes = axes )
243- result = _map_result (result , undo_swap_fn )
191+ result = _postprocess_coeffs (result , ndim = 2 , ds = ds , axes = axes )
244192
245193 return result
246194
@@ -286,35 +234,16 @@ def waverec2(
286234 >>> reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar"))
287235
288236 """
289- if tuple (axes ) != (- 2 , - 1 ):
290- if len (axes ) != 2 :
291- raise ValueError ("2D transforms work with two axes." )
292- else :
293- _check_axes_argument (list (axes ))
294- swap_fn = partial (_swap_axes , axes = list (axes ))
295- coeffs = _map_result (coeffs , swap_fn )
296-
297- ds = None
298- wavelet = _as_wavelet (wavelet )
299-
300- res_ll = _check_if_tensor (coeffs [0 ])
301- torch_device = res_ll .device
302- torch_dtype = res_ll .dtype
303-
304- if res_ll .dim () >= 4 :
305- # avoid the channel sum, fold the channels into batches.
306- coeffs , ds = _waverec2d_fold_channels_2d_list (coeffs )
307- res_ll = _check_if_tensor (coeffs [0 ])
308-
309- if not _is_dtype_supported (torch_dtype ):
310- raise ValueError (f"Input dtype { torch_dtype } not supported" )
237+ coeffs , ds = _preprocess_coeffs (coeffs , ndim = 2 , axes = axes )
238+ torch_device , torch_dtype = _check_same_device_dtype (coeffs )
311239
312240 _ , _ , rec_lo , rec_hi = _get_filter_tensors (
313241 wavelet , flip = False , device = torch_device , dtype = torch_dtype
314242 )
315243 filt_len = rec_lo .shape [- 1 ]
316244 rec_filt = _construct_2d_filt (lo = rec_lo , hi = rec_hi )
317245
246+ res_ll = coeffs [0 ]
318247 for c_pos , coeff_tuple in enumerate (coeffs [1 :]):
319248 if not isinstance (coeff_tuple , tuple ) or len (coeff_tuple ) != 3 :
320249 raise ValueError (
@@ -325,11 +254,7 @@ def waverec2(
325254
326255 curr_shape = res_ll .shape
327256 for coeff in coeff_tuple :
328- if torch_device != coeff .device :
329- raise ValueError ("coefficients must be on the same device" )
330- elif torch_dtype != coeff .dtype :
331- raise ValueError ("coefficients must have the same dtype" )
332- elif coeff .shape != curr_shape :
257+ if coeff .shape != curr_shape :
333258 raise ValueError (
334259 "All coefficients on each level must have the same shape"
335260 )
@@ -362,10 +287,6 @@ def waverec2(
362287 if padr > 0 :
363288 res_ll = res_ll [..., :- padr ]
364289
365- if ds :
366- res_ll = _unfold_axes (res_ll , list (ds ), 2 )
367-
368- if axes != (- 2 , - 1 ):
369- res_ll = _undo_swap_axes (res_ll , list (axes ))
290+ res_ll = _postprocess_tensor (res_ll , ndim = 2 , ds = ds , axes = axes )
370291
371292 return res_ll
0 commit comments