First working example

Fangjun Kuang 2025-05-30 15:42:31 +08:00
parent 516696f3e4
commit 0f88a3a6c3
2 changed files with 83 additions and 26 deletions

View File

@@ -28,7 +28,6 @@ import torch
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
 from lhotse.cut import Cut
 from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
-    CutConcatenate,
     CutMix,
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
@@ -57,21 +56,33 @@ class _SeedWorkers:
         fix_random_seed(self.seed + worker_id)


+"""
+We use c.features = None below to suppress the following warnings
+
+2025-05-29 16:49:55,253 WARNING [data.py:801] Attempting to perturb speed on a
+DataCut that references pre-computed features. The feature manifest will be
+detached, as we do not support feature-domain speed perturbation.
+"""
+
+
 def perturb_speed(c: Cut):
-    factor = random.uniform(0.9, 1.1)
-    print("perturb_speed factor", factor)
+    factor = random.choice([0.9, 1.1])
+    c.features = None
     return lhotse.MonoCut.perturb_speed(c, factor)


 def perturb_volume(c: Cut):
-    factor = random.uniform(0.9, 1.1)
-    print("perturb_volume factor", factor)
+    factor = random.choice([0.9, 1.1])
+    c.features = None
     return lhotse.MonoCut.perturb_volume(c, factor)


 def perturb_tempo(c: Cut):
-    factor = random.uniform(0.9, 1.1)
-    print("perturb_tempo factor", factor)
+    factor = random.choice([0.9, 1.1])
+    c.features = None
     return lhotse.MonoCut.perturb_tempo(c, factor)
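
The three helpers above operate on individual cuts and clear `c.features` so Lhotse does not fall back to (unsupported) feature-domain perturbation. As a rough illustration outside this commit, the same kind of per-cut transform can also be applied lazily to a whole `CutSet` via `CutSet.map`; the manifest path below is a placeholder, and the cuts are assumed to reference recordings so time-domain perturbation is possible:

```python
# Illustrative sketch only (not part of this commit): apply a per-cut
# speed perturbation like the one above across a whole CutSet.
# The manifest path is a placeholder; cuts must reference recordings.
import random

import lhotse
from lhotse import CutSet
from lhotse.cut import Cut


def perturb_speed(c: Cut) -> Cut:
    factor = random.choice([0.9, 1.1])
    # Drop pre-computed features so Lhotse does not warn that
    # feature-domain speed perturbation is unsupported.
    c.features = None
    return lhotse.MonoCut.perturb_speed(c, factor)


cuts = CutSet.from_file("data/fbank/cuts_train.jsonl.gz")  # placeholder path
augmented = cuts.map(perturb_speed)  # evaluated lazily, cut by cut
```
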
@@ -86,7 +97,6 @@ class LibriSpeechAsrDataModuleWithParallelAug:
     experiments, e.g.:
     - dynamic batch size,
     - bucketing samplers,
-    - cut concatenation,
     - augmentation,
     - on-the-fly feature extraction
@@ -112,6 +122,12 @@ class LibriSpeechAsrDataModuleWithParallelAug:
             help="""Used only when --mini-libri is False.When enabled,
             use 960h LibriSpeech. Otherwise, use 100h subset.""",
         )
+        group.add_argument(
+            "--enable-augmentation",
+            type=str2bool,
+            default=True,
+            help="True to enable augmentation for training set",
+        )
         group.add_argument(
             "--mini-libri",
             type=str2bool,
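
For reference, a minimal self-contained sketch of how the new flag would be parsed on the command line; the `str2bool` here is a local stand-in approximating the helper that the data module imports from icefall.utils:

```python
# Self-contained sketch of parsing --enable-augmentation. This str2bool is
# a local stand-in; the real one comes from icefall.utils.
import argparse


def str2bool(v: str) -> bool:
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")


parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-augmentation",
    type=str2bool,
    default=True,
    help="True to enable augmentation for training set",
)

print(parser.parse_args([]).enable_augmentation)  # True
print(parser.parse_args(["--enable-augmentation", "0"]).enable_augmentation)  # False
```
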
@@ -204,7 +220,12 @@ class LibriSpeechAsrDataModuleWithParallelAug:
          sampler_state_dict:
            The state dict for the training sampler.
        """
-        transforms = [perturb_speed, perturb_volume, perturb_tempo]
+        if self.args.enable_augmentation:
+            logging.info("Augmentation is enabled")
+            transforms = [perturb_speed, perturb_volume, perturb_tempo]
+        else:
+            logging.info("Augmentation is disabled")
+            transforms = []

         logging.info("About to create train dataset")
         train = ConsistencyRegularizationSpeechRecognitionDataset(
@@ -254,12 +275,6 @@ class LibriSpeechAsrDataModuleWithParallelAug:
     def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
         transforms = []
-        if self.args.concatenate_cuts:
-            transforms = [
-                CutConcatenate(
-                    duration_factor=self.args.duration_factor, gap=self.args.gap
-                )
-            ] + transforms

         logging.info("About to create dev dataset")
         if self.args.on_the_fly_feats:

View File

@@ -855,19 +855,38 @@ def compute_loss(
       disables autograd.
     """
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
-    feature = batch["inputs"]
+
+    feature_len_seq = [batch["supervisions"]["num_frames"]]
+    text_seq = list(batch["supervisions"]["text"])
+    feature_seq = torch.nn.utils.rnn.unpad_sequence(
+        batch["inputs"],
+        batch["supervisions"]["num_frames"],
+        batch_first=True,
+    )
+
+    if "aug" in batch:
+        for aug in batch["aug"]:
+            feature_len_seq.append(aug["supervisions"]["num_frames"])
+            text_seq.extend(aug["supervisions"]["text"])
+            feature_seq.extend(
+                torch.nn.utils.rnn.unpad_sequence(
+                    aug["inputs"],
+                    aug["supervisions"]["num_frames"],
+                    batch_first=True,
+                )
+            )
+
+    feature_lens = torch.cat(feature_len_seq).to(device)
+    feature = torch.nn.utils.rnn.pad_sequence(feature_seq, batch_first=True).to(device)
+
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
-    feature = feature.to(device)
-
-    supervisions = batch["supervisions"]
-    feature_lens = supervisions["num_frames"].to(device)

     batch_idx_train = params.batch_idx_train
     warm_step = params.warm_step

-    texts = batch["supervisions"]["text"]
-    y = sp.encode(texts, out_type=int)
+    y = sp.encode(text_seq, out_type=int)
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
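
The added block merges the original batch with its augmented copies before padding everything into one tensor. A tiny self-contained sketch of that merge, with made-up shapes for illustration:

```python
# Sketch of the merge done in compute_loss(): unpad the original batch and
# one augmented copy, concatenate the per-utterance tensors, then re-pad
# everything into a single (N, T, C) tensor.
import torch
from torch.nn.utils.rnn import pad_sequence, unpad_sequence

orig = torch.randn(2, 5, 80)            # original batch, padded to T=5
orig_lens = torch.tensor([5, 3])        # true frame counts
aug = torch.randn(2, 7, 80)             # augmented copy, padded to T=7
aug_lens = torch.tensor([7, 4])

seqs = unpad_sequence(orig, orig_lens, batch_first=True)
seqs += unpad_sequence(aug, aug_lens, batch_first=True)

feature = pad_sequence(seqs, batch_first=True)   # shape (4, 7, 80)
feature_lens = torch.cat([orig_lens, aug_lens])  # tensor([5, 3, 7, 4])
print(feature.shape, feature_lens)
```
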
@@ -1029,6 +1048,9 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

+        if "aug" in batch:
+            batch_size *= len(batch["aug"]) + 1
+
         try:
             with torch.cuda.amp.autocast(
                 enabled=params.use_autocast, dtype=params.dtype
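
The adjusted count simply reflects that each step now processes the original cuts plus every augmented copy; the numbers below are assumed values for illustration:

```python
# Sketch of the batch-size bookkeeping above: 16 original cuts with two
# augmented copies each are logged as an effective batch size of 48.
batch_size = 16      # len(batch["supervisions"]["text"])
num_aug_copies = 2   # len(batch["aug"]), an assumed value
batch_size *= num_aug_copies + 1
print(batch_size)    # 48
```
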
@@ -1360,7 +1382,7 @@ def run(rank, world_size, args):
         valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

-    if not params.print_diagnostics:
+    if False and not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
@@ -1443,12 +1465,32 @@ def display_and_save_batch(
     logging.info(f"Saving batch to {filename}")
     torch.save(batch, filename)

-    supervisions = batch["supervisions"]
-    features = batch["inputs"]
+    feature_len_seq = [batch["supervisions"]["num_frames"]]
+    text_seq = list(batch["supervisions"]["text"])
+    feature_seq = torch.nn.utils.rnn.unpad_sequence(
+        batch["inputs"],
+        batch["supervisions"]["num_frames"],
+        batch_first=True,
+    )
+
+    if "aug" in batch:
+        for aug in batch["aug"]:
+            feature_len_seq.append(aug["supervisions"]["num_frames"])
+            text_seq.extend(aug["supervisions"]["text"])
+            feature_seq.extend(
+                torch.nn.utils.rnn.unpad_sequence(
+                    aug["inputs"],
+                    aug["supervisions"]["num_frames"],
+                    batch_first=True,
+                )
+            )
+
+    features = torch.nn.utils.rnn.pad_sequence(feature_seq, batch_first=True)

     logging.info(f"features shape: {features.shape}")

-    y = sp.encode(supervisions["text"], out_type=int)
+    y = sp.encode(text_seq, out_type=int)
     num_tokens = sum(len(i) for i in y)
     logging.info(f"num tokens: {num_tokens}")