Commit f670319

training asr
1 parent ae3eb0f · commit f670319

7 files changed (+18, -18 lines)

.gitignore

Lines changed: 2 additions & 1 deletion

@@ -1,6 +1,7 @@
 #custom
 ctcdecode/
-
+wandb/
+yttm*

 # Byte-compiled / optimized / DLL files
 __pycache__/

configs/config.yml

Lines changed: 4 additions & 4 deletions

@@ -1,15 +1,15 @@
 dataset:
-  root: data/custom_data
+  root: data/
   train_part: 0.95
-  name: vlsp
+  name: custom_data
   sample_rate: 22050
 bpe:
   train: true
   model_path: yttm.bpe
 train:
   seed: 42
-  num_workers: 16
-  batch_size: 32
+  num_workers: 4
+  batch_size: 1
   clip_grad_norm: 15
   epochs: 42
 optimizer:
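The config now points the loader at data/ with dataset name custom_data, and drops num_workers to 4 and batch_size to 1, likely for local debugging. The training code reads these keys through attribute access (config.dataset.root, config.bpe.get('train', False)); a minimal loading sketch, assuming OmegaConf rather than whatever wrapper tools/train.py actually builds from the opened YAML file:

# Sketch only: OmegaConf is an assumption, not necessarily the repo's loader.
from omegaconf import OmegaConf

config = OmegaConf.load('configs/config.yml')
print(config.dataset.root)             # 'data/'
print(config.bpe.get('train', False))  # True
print(config.train.batch_size)         # 1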

data/.gitignore

Lines changed: 2 additions & 1 deletion

@@ -1,2 +1,3 @@
 vlsp*/
-custom_data
+custom_data
+LJ*

data/custom.py

Lines changed: 3 additions & 3 deletions

@@ -10,18 +10,18 @@ def get_text(path_to_script_file):
     return text


-def make(root='data', name='vlsp2020_train_set_02', save='custom_data'):
+def make(root='data', name='vlsp2020_train_set_02', save='LJSpeech-1.1'):
     if not os.path.isdir(os.path.join(root, save, 'wavs')):
         os.mkdir(os.path.join(root, save, 'wavs'))
     all_audio_file_paths = glob.glob(os.path.join(root, name,"*.wav"))
     all_script_file_paths = glob.glob(os.path.join(root, name,"*.txt"))
     for audio_file in tqdm(all_audio_file_paths):
-        file_name = audio_file.split('\\')[-1]
+        file_name = audio_file.split('/')[-1]
         os.rename(audio_file, os.path.join(root,save,"wavs",file_name))
     with open(os.path.join(root,save,'metadata.csv'), 'w', encoding='UTF8', newline='') as f:
         writer = csv.writer(f, delimiter="|")
         for text_file in tqdm(all_script_file_paths):
-            file_name = text_file.split("\\")[-1].split(".")[0]
+            file_name = text_file.split("/")[-1].split(".")[0]
             text = get_text(text_file)
             row = [file_name,text,text]
             writer.writerow(row)
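The separator change ('\\' to '/') makes the filename split work on POSIX paths, at the cost of breaking it again on Windows. A platform-neutral alternative (a sketch, not part of this commit) would let os.path do the splitting:

import os

# Hypothetical replacement for the manual split on '/' or '\\':
file_name = os.path.basename(audio_file)                 # e.g. 'xyz.wav'
stem = os.path.splitext(os.path.basename(text_file))[0]  # e.g. 'xyz'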

datasets/dataset.py

Lines changed: 1 addition & 2 deletions

@@ -15,12 +15,11 @@ def __getitem__(self, idx):
         return self.transforms({'audio' : audio, 'text': norm_text, 'sample_rate': sample_rate})

     def get_text(self, n):
-        line = self._walker[n]
+        line = self._flist[n]
         fileid, transcript, normalized_transcript = line
         return self.transforms({'text' : normalized_transcript})['text']


-
 def get_dataset(config, transforms=lambda x: x, part='train'):
     if part == 'train':
         dataset = LJSpeechDataset(root=config.dataset.root, download=False, transforms=transforms)
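The switch from _walker to _flist matches the attribute that recent torchaudio releases use internally in torchaudio.datasets.LJSPEECH to hold the parsed metadata.csv rows as (fileid, transcript, normalized_transcript) tuples; _walker is the name used by other torchaudio datasets such as LIBRISPEECH. A quick check of what the installed version exposes (sketch; assumes torchaudio is installed and an LJSpeech-style dataset sits under data/LJSpeech-1.1):

import torchaudio

ds = torchaudio.datasets.LJSPEECH(root='data/', download=False)
print(hasattr(ds, '_flist'), hasattr(ds, '_walker'))  # expected: True False on recent versions
fileid, transcript, normalized_transcript = ds._flist[0]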

tools/train.py

Lines changed: 3 additions & 4 deletions

@@ -130,7 +130,7 @@ def train(config):

     criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
     # criterion = nn.CTCLoss(blank=config.model.vocab_size)
-    decoder = BeamCTCDecoder(bpe=bpe)
+    decoder = GreedyDecoder(bpe=bpe)

     prev_wer = 1000
     wandb.init(project=config.wandb.project, config=config)

@@ -162,10 +162,9 @@ def train(config):
                 "train_cer": cer,
                 "train_samples": wandb.Table(
                     columns=['gt_text', 'pred_text'],
-                    data=zip(target_strings, decoded_output)
+                    data=list(zip(target_strings, decoded_output))
                 )
             }, step=step)
-
         # validate:
         model.eval()
         val_stats = defaultdict(list)

@@ -202,7 +201,7 @@

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Training model.')
-    parser.add_argument('--config', default='configs/train_LJSpeech.yml',
+    parser.add_argument('--config', default='configs/config.yml',
                         help='path to config file')
     args = parser.parse_args()
     with open(args.config, 'r') as f:
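Two functional changes here: training-time metrics now use GreedyDecoder instead of BeamCTCDecoder (cheaper per logging step), and the wandb.Table rows are materialized with list(zip(...)), since a bare zip object is a one-shot iterator rather than the list of rows wandb.Table expects. The repository's GreedyDecoder itself is not shown in this diff; as a rough sketch of what greedy CTC decoding does, assuming blank id 0 (matching the CTCLoss above) and a loaded youtokentome model:

import torch

def greedy_ctc_decode(log_probs, bpe, blank_id=0):
    """Sketch: log_probs is a (time, vocab) tensor of per-frame scores."""
    ids = torch.argmax(log_probs, dim=-1).tolist()
    out, prev = [], None
    for i in ids:
        if i != prev and i != blank_id:  # collapse repeats, drop CTC blanks
            out.append(i)
        prev = i
    return bpe.decode([out])[0]  # youtokentome decodes a list of id lists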

utils/utils.py

Lines changed: 3 additions & 3 deletions

@@ -5,7 +5,8 @@
 import youtokentome as yttm
 import os
 import importlib
-from data.transforms import TextPreprocess
+from datasets.transforms import TextPreprocess
+from datasets.dataset import get_dataset


 def fix_seeds(seed=42):

@@ -22,10 +23,9 @@ def remove_from_dict(the_dict, keys):
     return the_dict

 def prepare_bpe(config):
-    dataset_module = importlib.import_module(f'.{config.dataset.name}', data.__name__)
     # train BPE
     if config.bpe.get('train', False):
-        dataset, ids = dataset_module.get_dataset(config, part='bpe', transforms=TextPreprocess())
+        dataset, ids = get_dataset(config, part='bpe', transforms=TextPreprocess())
         train_data_path = 'bpe_texts.txt'
         with open(train_data_path, "w") as f:
             # run ovefr only train part
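prepare_bpe no longer imports a per-dataset module via importlib; it calls get_dataset from datasets.dataset directly. The rest of the function is cut off in this hunk, but it presumably writes the training transcripts to bpe_texts.txt (the open-for-write is visible above) and then trains a YouTokenToMe model saved at config.bpe.model_path. A sketch of that training step under those assumptions:

import youtokentome as yttm

# Hypothetical continuation of prepare_bpe; the actual body is not in this diff.
yttm.BPE.train(
    data='bpe_texts.txt',         # one transcript per line
    model=config.bpe.model_path,  # 'yttm.bpe' in configs/config.yml
    vocab_size=500,               # placeholder; the real value likely comes from the config
)
bpe = yttm.BPE(model=config.bpe.model_path)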
