Describe the bug
I trained a model on French-only (and also tried English-only) data using canary-1b-v2 as the encoder and meta-llama/Llama-3.2-1B-Instruct as the LLM. Training ran for 50k steps on 1 GPU (about 3 hours), and the last checkpoint has a WER of 171 (evaluated with the salm_eval script).
Here are a few of the output rows:
{"duration": 3.96, "text": "la principale religion en moldavie est le christianisme orthodoxe", "pred_text": "qu est ce que vous faites pour vous faire plaisir de vous faire du mal"}
{"duration": 5.04, "text": "les premiers cas de cette maladie saisonniere ont ete declares fin juillet", "pred_text": "il est possible que les gens ne soient pas tres interesses par les problemes de la vie quotidienne mais ils sont tres interesses par les problemes de la vie politique"}
{"duration": 5.1, "text": "trois autres bombes ont explose pres des batiments du gouvernement en l espace de deux heures", "pred_text": "je suis desole mais je ne peux pas vous donner de nouvelles sur le cas de la femme qui a disparu"}
Is 3 hours of training not enough to start seeing sensible transcriptions?
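For reference, WER = (substitutions + deletions + insertions) / reference length, so it can exceed 100 when the predictions are longer than and unrelated to the references, as in the rows above. A quick sanity check on the first row, using jiwer purely as an illustration (salm_eval may compute WER differently):

import jiwer

ref = "la principale religion en moldavie est le christianisme orthodoxe"
hyp = "qu est ce que vous faites pour vous faire plaisir de vous faire du mal"
# The prediction is longer than the reference and shares almost no words with it,
# so the insertions alone push the score well above 1.0 (i.e. above 100% WER).
print(jiwer.wer(ref, hyp))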
Steps/Code to reproduce bug
Datasets used: FLEURS and Multilingual LibriSpeech, both in French
Config:
############ Model ############
model:
pretrained_llm: meta-llama/Llama-3.2-1B-Instruct
pretrained_asr: nvidia/canary-1b-v2
pretrained_weights: True # When False, we use pretrained_name to load the architecture, but with random init
# Training-time prediction logging; does not work
# log_prediction_train: true
# log_prediction_train_samples: 5
# log_prediction_train_interval: 50
freeze_params:
# Frozen LLM
- "^llm\\..+$" # LLM
- "^embed_tokens\\..+$" # LLM embedding is moved
# Frozen pretrained ASR (only the modality adapter layers are trainable)
- "^perception\\.preprocessor\\..+$"
- "^perception\\.encoder\\..+$"
prevent_freeze_params: [] # Use to make specific submodules trainable; overrides freeze_params
prompt_format: llama3
audio_locator_tag: "<|audio|>"
perception:
target: nemo.collections.speechlm2.modules.perception.AudioPerceptionModule
modality_adapter:
_target_: nemo.collections.asr.modules.ConformerEncoder
feat_in: 1024
feat_out: -1 # you may set it if you need different output size other than the default d_model
n_layers: 2
d_model: 1024
subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding
subsampling_factor: 1 # must be power of 2 for striding and vggnet
subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model
causal_downsampling: false
ff_expansion_factor: 4
self_attention_model: rel_pos # rel_pos or abs_pos
n_heads: 8 # may need to be lower for smaller d_models
# [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
att_context_size: [-1, -1]
att_context_style: regular # regular or chunked_limited
xscaling: true # scales up the input embeddings by sqrt(d_model)
untie_biases: true # unties the biases of the TransformerXL layers
pos_emb_max_len: 5000
conv_kernel_size: 9
conv_norm_type: batch_norm # batch_norm or layer_norm or groupnormN (N specifies the number of groups)
# conv_context_size can be "causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size
# null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0]
conv_context_size: null
### regularization
dropout: 0 # The dropout used in most of the Conformer Modules
dropout_pre_encoder: 0 # The dropout used before the encoder
dropout_emb: 0.0 # The dropout used for embeddings
dropout_att: 0 # The dropout for multi-headed attention modules
optimizer:
_target_: torch.optim.AdamW
lr: 1e-5
betas: [0.9, 0.98]
weight_decay: 1e-3
foreach: true # set to false if having issues with tensor-parallelism
lr_scheduler:
_target_: nemo.core.optim.lr_scheduler.CosineAnnealing
warmup_steps: 1000
min_lr: 1e-6
max_steps: ${trainer.max_steps}
trainer:
devices: -1
accelerator: gpu
num_nodes: 1
precision: bf16-true
max_epochs: -1
max_steps: 50000
limit_train_batches: 10_000 # number of steps per "epoch", since lhotse datasets have no epoch boundary
accumulate_grad_batches: 2
limit_val_batches: 0.0
val_check_interval: 5000
num_sanity_val_steps: 1
sync_batchnorm: true
log_every_n_steps: 5
logger: false
use_distributed_sampler: false
enable_checkpointing: false
gradient_clip_val: 1.0
strategy:
# Replace DDPStrategy with ModelParallelStrategy to enable model parallelism
_target_: lightning.pytorch.strategies.DDPStrategy
gradient_as_bucket_view: true
find_unused_parameters: true
# _target_: lightning.pytorch.strategies.ModelParallelStrategy
# tensor_parallel_size: 2
# data_parallel_size: 2
exp_manager:
exp_dir: null
name: test
create_tensorboard_logger: true
create_checkpoint_callback: true
use_datetime_version: false
max_time_per_run: 00:01:45:00
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
# you need to set these two to True to continue the training
resume_if_exists: true
resume_ignore_no_checkpoint: true
############ Data ############
data:
train_ds:
sample_rate: 16000
prompt_format: ${model.prompt_format}
token_equivalent_duration: 0.08
input_cfg:
- type: group
tags:
lang: en
pnc: yes
task: asr
input_cfg:
- type: multimodal_conversation
manifest_filepath: manifest_FLEURS_casepunc_train.jsonl
audio_locator_tag: ${model.audio_locator_tag}
- type: multimodal_conversation
manifest_filepath: Multilingual_LibriSpeech/manifest_Multilingual_LibriSpeech_recasepunc_train.jsonl
audio_locator_tag: ${model.audio_locator_tag}
seed: 42
shuffle: true
shard_seed: "randomized"
text_field: 'answer'
use_lhotse: true
num_workers: 4
# batch_duration: 480
# batch_size: 32
max_duration: 90
# min_duration: 0.1
# max_tokens: 1024
min_tokens: 3
# Optional bucketing:
batch_size: null
use_bucketing: true
use_multimodal_sampling: true
measure_total_length: true
batch_tokens: 4096
max_tokens: 1024
bucket_duration_bins: [128, 256, 512, 1024]
num_buckets: 4
bucket_buffer_size: 5_000
Environment details
NeMo 2.4 on H100
Additional context
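My understanding is that the freeze_params entries are regular expressions matched against parameter names, so with the config above only the modality adapter should be trainable. A rough sketch of that logic, for illustration only (not NeMo's actual implementation):

import re
import torch.nn as nn

def freeze_matching(model: nn.Module, patterns: list[str]) -> None:
    # Freeze every parameter whose fully qualified name matches one of the
    # regex patterns; parameters that match nothing stay trainable.
    compiled = [re.compile(p) for p in patterns]
    for name, param in model.named_parameters():
        if any(p.match(name) for p in compiled):
            param.requires_grad = False

freeze_params = [
    r"^llm\..+$",                       # frozen LLM
    r"^embed_tokens\..+$",              # LLM embedding (moved out of the LLM)
    r"^perception\.preprocessor\..+$",  # frozen ASR preprocessor
    r"^perception\.encoder\..+$",       # frozen ASR encoder
]
# freeze_matching(model, freeze_params)  # 'model' would be the SALM model instance (hypothetical usage)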