Skip to content

Commit 37f82de

Browse files
committed
use future annotations
1 parent ade4fd6 commit 37f82de

File tree

4 files changed

+62
-56
lines changed

4 files changed

+62
-56
lines changed

‎audiolm_pytorch/audiolm_pytorch.py‎

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
from __future__ import annotations
2+
13
import math
24
from functools import partial, wraps
35

4-
from beartype.typing import Optional, Union, List
56
from beartype import beartype
67

78
import torch
@@ -625,7 +626,7 @@ def forward(
625626
*,
626627
ids = None,
627628
return_loss = False,
628-
text: Optional[List[str]] = None,
629+
text: list[str] | None = None,
629630
text_embeds = None,
630631
self_attn_mask = None,
631632
cond_drop_prob = None,
@@ -813,7 +814,7 @@ def forward(
813814
semantic_token_ids,
814815
coarse_token_ids,
815816
self_attn_mask = None,
816-
text: Optional[List[str]] = None,
817+
text: list[str] | None = None,
817818
text_embeds = None,
818819
cond_drop_prob = None,
819820
return_only_coarse_logits = False,
@@ -1089,7 +1090,7 @@ def forward(
10891090
self,
10901091
coarse_token_ids,
10911092
fine_token_ids,
1092-
text: Optional[List[str]] = None,
1093+
text: list[str] | None = None,
10931094
text_embeds = None,
10941095
cond_drop_prob = None,
10951096
self_attn_mask = None,
@@ -1327,8 +1328,8 @@ def __init__(
13271328
self,
13281329
*,
13291330
transformer: SemanticTransformer,
1330-
wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
1331-
audio_conditioner: Optional[AudioConditionerBase] = None,
1331+
wav2vec: FairseqVQWav2Vec | HubertWithKmeans | None = None,
1332+
audio_conditioner: AudioConditionerBase | None = None,
13321333
pad_id = -1,
13331334
unique_consecutive = True,
13341335
mask_prob = 0.15
@@ -1362,7 +1363,7 @@ def generate(
13621363
self,
13631364
*,
13641365
max_length,
1365-
text: Optional[List[str]] = None,
1366+
text: list[str] | None = None,
13661367
text_embeds = None,
13671368
prime_wave = None,
13681369
prime_wave_input_sample_hz = None,
@@ -1524,9 +1525,9 @@ def __init__(
15241525
self,
15251526
*,
15261527
transformer: CoarseTransformer,
1527-
codec: Optional[Union[SoundStream, EncodecWrapper]] = None,
1528-
wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
1529-
audio_conditioner: Optional[AudioConditionerBase] = None,
1528+
codec: SoundStream | EncodecWrapper | None = None,
1529+
wav2vec: FairseqVQWav2Vec | HubertWithKmeans | None = None,
1530+
audio_conditioner: AudioConditionerBase | None = None,
15301531
pad_id = -1,
15311532
unique_consecutive = True,
15321533
semantic_cross_entropy_loss_weight = 1.,
@@ -1564,10 +1565,10 @@ def generate(
15641565
self,
15651566
*,
15661567
semantic_token_ids,
1567-
prime_wave: Optional[Tensor] = None,
1568+
prime_wave: Tensor | None = None,
15681569
prime_wave_input_sample_hz = None,
1569-
prime_coarse_token_ids: Optional[Tensor] = None,
1570-
text: Optional[List[str]] = None,
1570+
prime_coarse_token_ids: Tensor | None = None,
1571+
text: list[str] | None = None,
15711572
text_embeds = None,
15721573
max_time_steps = 512,
15731574
cond_scale = 3.,
@@ -1811,8 +1812,8 @@ def __init__(
18111812
self,
18121813
*,
18131814
transformer: FineTransformer,
1814-
codec: Optional[Union[SoundStream, EncodecWrapper]] = None,
1815-
audio_conditioner: Optional[AudioConditionerBase] = None,
1815+
codec: SoundStream | EncodecWrapper | None = None,
1816+
audio_conditioner: AudioConditionerBase | None = None,
18161817
coarse_cross_entropy_loss_weight = 1.,
18171818
pad_id = -1,
18181819
mask_prob = 0.15
@@ -1852,10 +1853,10 @@ def generate(
18521853
self,
18531854
*,
18541855
coarse_token_ids,
1855-
prime_wave: Optional[Tensor] = None,
1856+
prime_wave: Tensor | None = None,
18561857
prime_wave_input_sample_hz = None,
1857-
prime_fine_token_ids: Optional[Tensor] = None,
1858-
text: Optional[List[str]] = None,
1858+
prime_fine_token_ids: Tensor | None = None,
1859+
text: list[str] | None = None,
18591860
text_embeds = None,
18601861
cond_scale = 3.,
18611862
filter_thres = 0.9,
@@ -2095,12 +2096,12 @@ class AudioLM(nn.Module):
20952096
def __init__(
20962097
self,
20972098
*,
2098-
wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
2099-
codec: Union[SoundStream, EncodecWrapper],
2099+
wav2vec: FairseqVQWav2Vec | HubertWithKmeans | None,
2100+
codec: SoundStream | EncodecWrapper,
21002101
semantic_transformer: SemanticTransformer,
21012102
coarse_transformer: CoarseTransformer,
21022103
fine_transformer: FineTransformer,
2103-
audio_conditioner: Optional[AudioConditionerBase] = None,
2104+
audio_conditioner: AudioConditionerBase | None = None,
21042105
unique_consecutive = True
21052106
):
21062107
super().__init__()
@@ -2148,8 +2149,8 @@ def forward(
21482149
self,
21492150
*,
21502151
batch_size = 1,
2151-
text: Optional[List[str]] = None,
2152-
text_embeds: Optional[Tensor] = None,
2152+
text: list[str] | None = None,
2153+
text_embeds: Tensor | None = None,
21532154
prime_wave = None,
21542155
prime_wave_input_sample_hz = None,
21552156
prime_wave_path = None,

‎audiolm_pytorch/data.py‎

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from pathlib import Path
24
from functools import partial, wraps
35

@@ -35,10 +37,10 @@ class SoundDataset(Dataset):
3537
def __init__(
3638
self,
3739
folder,
38-
target_sample_hz: Union[int, Tuple[int, ...]], # target sample hz must be specified, or a tuple of them if one wants to return multiple resampled
40+
target_sample_hz: int | Tuple[int, ...], # target sample hz must be specified, or a tuple of them if one wants to return multiple resampled
3941
exts = ['flac', 'wav', 'mp3', 'webm'],
40-
max_length: Optional[int] = None, # max length would apply to the highest target_sample_hz, if there are multiple
41-
seq_len_multiple_of: Optional[Union[int, Tuple[Optional[int], ...]]] = None
42+
max_length: int | None = None, # max length would apply to the highest target_sample_hz, if there are multiple
43+
seq_len_multiple_of: int | tuple[int | None, ...] | None = None
4244
):
4345
super().__init__()
4446
path = Path(folder)

‎audiolm_pytorch/soundstream.py‎

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from __future__ import annotations
2+
13
import functools
24
from pathlib import Path
35
from functools import partial, wraps
46
from itertools import cycle, zip_longest
5-
from typing import Optional, List
67

78
import torch
89
from torch import nn, einsum
@@ -455,8 +456,8 @@ def __init__(
455456
strides = (2, 4, 5, 8),
456457
channel_mults = (2, 4, 8, 16),
457458
codebook_dim = 512,
458-
codebook_size: Optional[int] = None,
459-
finite_scalar_quantizer_levels: Optional[List[int]] = None,
459+
codebook_size: int | None = None,
460+
finite_scalar_quantizer_levels: list[int] | None = None,
460461
rq_num_quantizers = 8,
461462
rq_commitment_weight = 1.,
462463
rq_ema_decay = 0.95,
@@ -492,7 +493,7 @@ def __init__(
492493
squeeze_excite = False,
493494
complex_stft_discr_logits_abs = True,
494495
pad_mode = 'reflect',
495-
stft_discriminator: Optional[Module] = None, # can pass in own stft discriminator
496+
stft_discriminator: Module | None = None, # can pass in own stft discriminator
496497
complex_stft_discr_kwargs: dict = dict()
497498
):
498499
super().__init__()

‎audiolm_pytorch/trainer.py‎

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import re
24
import copy
35
from math import sqrt
@@ -9,7 +11,7 @@
911
from collections import Counter
1012
from contextlib import contextmanager, nullcontext
1113

12-
from beartype.typing import Union, List, Optional, Tuple, Type
14+
from beartype.typing import Type
1315
from typing_extensions import Annotated
1416

1517
from beartype import beartype
@@ -79,7 +81,7 @@ def check_one_trainer():
7981
torch.Tensor,
8082
Is[lambda t: t.dtype == torch.float and t.ndim in {2, 3}]
8183
],
82-
text = List[str],
84+
text = list[str],
8385
text_embeds = Annotated[
8486
torch.Tensor,
8587
Is[lambda t: t.dtype == torch.float and t.ndim == 3]
@@ -166,7 +168,7 @@ def __init__(
166168
self,
167169
accelerator: Accelerator,
168170
optimizer: Optimizer,
169-
scheduler: Optional[Type[_LRScheduler]] = None,
171+
scheduler: Type[_LRScheduler] | None = None,
170172
scheduler_kwargs: dict = dict(),
171173
warmup_steps: int = 0
172174
):
@@ -216,20 +218,20 @@ def __init__(
216218
num_train_steps: int,
217219
batch_size: int,
218220
data_max_length: int = None,
219-
data_max_length_seconds: Union[int, float] = None,
221+
data_max_length_seconds: int | float = None,
220222
folder: str = None,
221-
dataset: Optional[Dataset] = None,
222-
val_dataset: Optional[Dataset] = None,
223-
train_dataloader: Optional[DataLoader] = None,
224-
val_dataloader: Optional[DataLoader] = None,
223+
dataset: Dataset | None = None,
224+
val_dataset: Dataset | None = None,
225+
train_dataloader: DataLoader | None = None,
226+
val_dataloader: DataLoader | None = None,
225227
lr: float = 2e-4,
226228
grad_accum_every: int = 4,
227229
wd: float = 0.,
228230
warmup_steps: int = 1000,
229-
scheduler: Optional[Type[_LRScheduler]] = None,
231+
scheduler: Type[_LRScheduler] | None = None,
230232
scheduler_kwargs: dict = dict(),
231-
discr_warmup_steps: Optional[int] = None,
232-
discr_scheduler: Optional[Type[_LRScheduler]] = None,
233+
discr_warmup_steps: int | None = None,
234+
discr_scheduler: Type[_LRScheduler] | None = None,
233235
discr_scheduler_kwargs: dict = dict(),
234236
max_grad_norm: float = 0.5,
235237
discr_max_grad_norm: float = None,
@@ -245,7 +247,7 @@ def __init__(
245247
ema_update_every: int = 10,
246248
apply_grad_penalty_every: int = 4,
247249
dl_num_workers: int = 0,
248-
accelerator: Optional[Accelerator] = None,
250+
accelerator: Accelerator | None = None,
249251
accelerate_kwargs: dict = dict(),
250252
init_process_group_timeout_seconds = 1800,
251253
dataloader_drop_last = True,
@@ -715,14 +717,14 @@ class SemanticTransformerTrainer(nn.Module):
715717
@beartype
716718
def __init__(
717719
self,
718-
wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
720+
wav2vec: FairseqVQWav2Vec | HubertWithKmeans | None,
719721
transformer: SemanticTransformer,
720722
*,
721723
num_train_steps,
722724
batch_size,
723-
audio_conditioner: Optional[AudioConditionerBase] = None,
724-
dataset: Optional[Dataset] = None,
725-
valid_dataset: Optional[Dataset] = None,
725+
audio_conditioner: AudioConditionerBase | None = None,
726+
dataset: Dataset | None = None,
727+
valid_dataset: Dataset | None = None,
726728
data_max_length = None,
727729
data_max_length_seconds = None,
728730
folder = None,
@@ -1009,15 +1011,15 @@ class CoarseTransformerTrainer(nn.Module):
10091011
def __init__(
10101012
self,
10111013
transformer: CoarseTransformer,
1012-
codec: Union[SoundStream, EncodecWrapper],
1013-
wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
1014+
codec: SoundStream | EncodecWrapper,
1015+
wav2vec: FairseqVQWav2Vec | HubertWithKmeans | None,
10141016
*,
10151017
num_train_steps,
10161018
batch_size,
1017-
audio_conditioner: Optional[AudioConditionerBase] = None,
1018-
dataset: Optional[Dataset] = None,
1019-
valid_dataset: Optional[Dataset] = None,
1020-
ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_codec', 'text'),
1019+
audio_conditioner: AudioConditionerBase | None = None,
1020+
dataset: Dataset | None = None,
1021+
valid_dataset: Dataset | None = None,
1022+
ds_fields: tuple[str, ...] = ('raw_wave', 'raw_wave_for_codec', 'text'),
10211023
data_max_length = None,
10221024
data_max_length_seconds = None,
10231025
folder = None,
@@ -1311,13 +1313,13 @@ class FineTransformerTrainer(nn.Module):
13111313
def __init__(
13121314
self,
13131315
transformer: FineTransformer,
1314-
codec: Union[SoundStream, EncodecWrapper],
1316+
codec: SoundStream | EncodecWrapper,
13151317
*,
13161318
num_train_steps,
13171319
batch_size,
1318-
audio_conditioner: Optional[AudioConditionerBase] = None,
1319-
dataset: Optional[Dataset] = None,
1320-
valid_dataset: Optional[Dataset] = None,
1320+
audio_conditioner: AudioConditionerBase | None = None,
1321+
dataset: Dataset | None = None,
1322+
valid_dataset: Dataset | None = None,
13211323
data_max_length = None,
13221324
data_max_length_seconds = None,
13231325
dataset_normalize = False,

0 commit comments

Comments
 (0)