First working example

Fangjun Kuang 2025-05-30 15:42:31 +08:00
parent 516696f3e4
commit 0f88a3a6c3
2 changed files with 83 additions and 26 deletions

View File

@@ -28,7 +28,6 @@ import torch
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
 from lhotse.cut import Cut
 from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
-    CutConcatenate,
     CutMix,
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
@@ -57,21 +56,33 @@ class _SeedWorkers:
         fix_random_seed(self.seed + worker_id)
 
 
+"""
+We use c.features = None below to suppress the following warnings
+
+2025-05-29 16:49:55,253 WARNING [data.py:801] Attempting to perturb speed on a
+DataCut that references pre-computed features. The feature manifest will be
+detached, as we do not support feature-domain speed perturbation.
+"""
+
+
 def perturb_speed(c: Cut):
-    factor = random.uniform(0.9, 1.1)
-    print("perturb_speed factor", factor)
+    factor = random.choice([0.9, 1.1])
+    c.features = None
     return lhotse.MonoCut.perturb_speed(c, factor)
 
 
 def perturb_volume(c: Cut):
-    factor = random.uniform(0.9, 1.1)
-    print("perturb_volume factor", factor)
+    factor = random.choice([0.9, 1.1])
+    c.features = None
     return lhotse.MonoCut.perturb_volume(c, factor)
 
 
 def perturb_tempo(c: Cut):
-    factor = random.uniform(0.9, 1.1)
-    print("perturb_tempo factor", factor)
+    factor = random.choice([0.9, 1.1])
+    c.features = None
     return lhotse.MonoCut.perturb_tempo(c, factor)
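Note on the hunk above: the transforms now pick a fixed factor from {0.9, 1.1} instead of sampling uniformly, the debug prints are gone, and `c.features = None` detaches pre-computed features so lhotse performs the perturbation in the audio domain. A minimal sketch of what one such call does to a cut, assuming lhotse is installed and the cuts reference their recordings; the manifest path is a hypothetical placeholder:

```python
# Minimal sketch (not the author's code): one speed-perturbation call on a cut.
import random

from lhotse import CutSet

cuts = CutSet.from_file("data/fbank/librispeech_cuts_dev-clean.jsonl.gz")  # hypothetical path
cut = next(iter(cuts))

factor = random.choice([0.9, 1.1])
cut.features = None  # detach pre-computed features to avoid the warning quoted above
aug = cut.perturb_speed(factor)
print(cut.duration, aug.duration)  # the augmented duration scales by 1 / factor
```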
@@ -86,7 +97,6 @@ class LibriSpeechAsrDataModuleWithParallelAug:
     experiments, e.g.:
 
     - dynamic batch size,
    - bucketing samplers,
-    - cut concatenation,
     - augmentation,
     - on-the-fly feature extraction
@@ -112,6 +122,12 @@ class LibriSpeechAsrDataModuleWithParallelAug:
             help="""Used only when --mini-libri is False.When enabled,
             use 960h LibriSpeech. Otherwise, use 100h subset.""",
         )
+        group.add_argument(
+            "--enable-augmentation",
+            type=str2bool,
+            default=True,
+            help="True to enable augmentation for training set",
+        )
         group.add_argument(
             "--mini-libri",
             type=str2bool,
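For reference, `str2bool` is icefall's helper that lets boolean flags be passed as strings on the command line. A self-contained sketch of the parsing behaviour; the `str2bool` body below is a stand-in, not icefall's exact implementation:

```python
# Sketch of how the new flag parses; str2bool here is a stand-in helper.
import argparse

def str2bool(v: str) -> bool:
    return v.lower() in ("yes", "true", "t", "y", "1")

parser = argparse.ArgumentParser()
parser.add_argument("--enable-augmentation", type=str2bool, default=True)

print(parser.parse_args([]).enable_augmentation)  # True (the default)
print(parser.parse_args(["--enable-augmentation", "false"]).enable_augmentation)  # False
```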
@@ -204,7 +220,12 @@ class LibriSpeechAsrDataModuleWithParallelAug:
           sampler_state_dict:
             The state dict for the training sampler.
         """
-        transforms = [perturb_speed, perturb_volume, perturb_tempo]
+        if self.args.enable_augmentation:
+            logging.info("Augmentation is enabled")
+            transforms = [perturb_speed, perturb_volume, perturb_tempo]
+        else:
+            logging.info("Augmentation is disabled")
+            transforms = []
 
         logging.info("About to create train dataset")
         train = ConsistencyRegularizationSpeechRecognitionDataset(
@@ -254,12 +275,6 @@ class LibriSpeechAsrDataModuleWithParallelAug:
     def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
         transforms = []
-        if self.args.concatenate_cuts:
-            transforms = [
-                CutConcatenate(
-                    duration_factor=self.args.duration_factor, gap=self.args.gap
-                )
-            ] + transforms
 
         logging.info("About to create dev dataset")
         if self.args.on_the_fly_feats:
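Taken together, the data-module changes mean each training batch can carry augmented copies of the cuts alongside the originals. Inferred from how `compute_loss` consumes batches in the next file, the layout looks roughly like this; a hedged sketch, not the dataset's documented contract:

```python
# Hedged sketch of the expected batch layout, inferred from compute_loss below.
def describe_batch(batch: dict) -> None:
    print("inputs:", tuple(batch["inputs"].shape))             # (N, T, C), padded
    print("num_frames:", batch["supervisions"]["num_frames"])  # true lengths per cut
    for i, aug in enumerate(batch.get("aug", [])):             # one entry per augmented view
        print(f"aug[{i}] inputs:", tuple(aug["inputs"].shape))
```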

View File

@@ -855,19 +855,38 @@ def compute_loss(
       disables autograd.
     """
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
-    feature = batch["inputs"]
+    feature_len_seq = [batch["supervisions"]["num_frames"]]
+    text_seq = list(batch["supervisions"]["text"])
+
+    feature_seq = torch.nn.utils.rnn.unpad_sequence(
+        batch["inputs"],
+        batch["supervisions"]["num_frames"],
+        batch_first=True,
+    )
+
+    if "aug" in batch:
+        for aug in batch["aug"]:
+            feature_len_seq.append(aug["supervisions"]["num_frames"])
+            text_seq.extend(aug["supervisions"]["text"])
+            feature_seq.extend(
+                torch.nn.utils.rnn.unpad_sequence(
+                    aug["inputs"],
+                    aug["supervisions"]["num_frames"],
+                    batch_first=True,
+                )
+            )
+
+    feature_lens = torch.cat(feature_len_seq).to(device)
+    feature = torch.nn.utils.rnn.pad_sequence(feature_seq, batch_first=True).to(device)
+
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
-    feature = feature.to(device)
-
-    supervisions = batch["supervisions"]
-    feature_lens = supervisions["num_frames"].to(device)
 
     batch_idx_train = params.batch_idx_train
     warm_step = params.warm_step
 
-    texts = batch["supervisions"]["text"]
-    y = sp.encode(texts, out_type=int)
+    y = sp.encode(text_seq, out_type=int)
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
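The core of this hunk is an unpad/pad round trip: the padded original batch and each augmented view are split back into per-cut tensors of their true lengths, concatenated into one list, and re-padded into a single (N_total, T_max, C) tensor, so originals and augmentations go through one forward pass. A runnable toy illustration with synthetic shapes; `unpad_sequence` requires PyTorch >= 1.13:

```python
# Runnable toy example of the unpad/pad round trip used above.
import torch
from torch.nn.utils.rnn import pad_sequence, unpad_sequence

inputs = torch.randn(2, 5, 80)      # original batch of 2 cuts, padded to T=5
num_frames = torch.tensor([5, 3])   # true lengths

aug_inputs = torch.randn(2, 6, 80)  # one augmented view, padded to its own T=6
aug_num_frames = torch.tensor([6, 4])

feature_seq = unpad_sequence(inputs, num_frames, batch_first=True)
feature_seq.extend(unpad_sequence(aug_inputs, aug_num_frames, batch_first=True))

feature_lens = torch.cat([num_frames, aug_num_frames])
feature = pad_sequence(feature_seq, batch_first=True)

print(feature.shape)  # torch.Size([4, 6, 80]): re-padded to the global max T
print(feature_lens)   # tensor([5, 3, 6, 4])
```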
@@ -1029,6 +1048,9 @@ def train_one_epoch(
             params.batch_idx_train += 1
             batch_size = len(batch["supervisions"]["text"])
+            if "aug" in batch:
+                batch_size *= len(batch["aug"]) + 1
+
             try:
                 with torch.cuda.amp.autocast(
                     enabled=params.use_autocast, dtype=params.dtype
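This keeps the logged batch size consistent with the enlarged batches: assuming each of the three transforms contributes one augmented view, a batch of N original cuts yields N * (3 + 1) utterances after `compute_loss` concatenates them. A worked example:

```python
# Worked example of the batch-size accounting (assumes one view per transform).
batch_size = 8       # len(batch["supervisions"]["text"])
num_views = 3        # len(batch["aug"])
batch_size *= num_views + 1
print(batch_size)    # 32
```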
@@ -1360,7 +1382,7 @@ def run(rank, world_size, args):
         valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
+    if False and not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
@@ -1443,12 +1465,32 @@ def display_and_save_batch(
     logging.info(f"Saving batch to {filename}")
     torch.save(batch, filename)
 
-    supervisions = batch["supervisions"]
-    features = batch["inputs"]
+    feature_len_seq = [batch["supervisions"]["num_frames"]]
+    text_seq = list(batch["supervisions"]["text"])
+
+    feature_seq = torch.nn.utils.rnn.unpad_sequence(
+        batch["inputs"],
+        batch["supervisions"]["num_frames"],
+        batch_first=True,
+    )
+
+    if "aug" in batch:
+        for aug in batch["aug"]:
+            feature_len_seq.append(aug["supervisions"]["num_frames"])
+            text_seq.extend(aug["supervisions"]["text"])
+            feature_seq.extend(
+                torch.nn.utils.rnn.unpad_sequence(
+                    aug["inputs"],
+                    aug["supervisions"]["num_frames"],
+                    batch_first=True,
+                )
+            )
+
+    features = torch.nn.utils.rnn.pad_sequence(feature_seq, batch_first=True)
 
     logging.info(f"features shape: {features.shape}")
 
-    y = sp.encode(supervisions["text"], out_type=int)
+    y = sp.encode(text_seq, out_type=int)
     num_tokens = sum(len(i) for i in y)
     logging.info(f"num tokens: {num_tokens}")