mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
First working example
This commit is contained in:
parent
516696f3e4
commit
0f88a3a6c3
@ -28,7 +28,6 @@ import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
@ -57,21 +56,33 @@ class _SeedWorkers:
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
"""
|
||||
We use c.features = None below to suppress the following warnings
|
||||
|
||||
2025-05-29 16:49:55,253 WARNING [data.py:801] Attempting to perturb speed on a
|
||||
DataCut that references pre-computed features. The feature manifest will be
|
||||
detached, as we do not support feature-domain speed perturbation.
|
||||
"""
|
||||
|
||||
|
||||
def perturb_speed(c: Cut):
|
||||
factor = random.uniform(0.9, 1.1)
|
||||
print("perturb_speed factor", factor)
|
||||
factor = random.choice([0.9, 1.1])
|
||||
c.features = None
|
||||
|
||||
return lhotse.MonoCut.perturb_speed(c, factor)
|
||||
|
||||
|
||||
def perturb_volume(c: Cut):
|
||||
factor = random.uniform(0.9, 1.1)
|
||||
print("perturb_volume factor", factor)
|
||||
factor = random.choice([0.9, 1.1])
|
||||
c.features = None
|
||||
|
||||
return lhotse.MonoCut.perturb_volume(c, factor)
|
||||
|
||||
|
||||
def perturb_tempo(c: Cut):
|
||||
factor = random.uniform(0.9, 1.1)
|
||||
print("perturb_tempo factor", factor)
|
||||
factor = random.choice([0.9, 1.1])
|
||||
|
||||
c.features = None
|
||||
return lhotse.MonoCut.perturb_tempo(c, factor)
|
||||
|
||||
|
||||
@ -86,7 +97,6 @@ class LibriSpeechAsrDataModuleWithParallelAug:
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
@ -112,6 +122,12 @@ class LibriSpeechAsrDataModuleWithParallelAug:
|
||||
help="""Used only when --mini-libri is False.When enabled,
|
||||
use 960h LibriSpeech. Otherwise, use 100h subset.""",
|
||||
)
|
||||
group.add_argument(
|
||||
"--enable-augmentation",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="True to enable augmentation for training set",
|
||||
)
|
||||
group.add_argument(
|
||||
"--mini-libri",
|
||||
type=str2bool,
|
||||
@ -204,7 +220,12 @@ class LibriSpeechAsrDataModuleWithParallelAug:
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
if self.args.enable_augmentation:
|
||||
logging.info("Augmentation is enabled")
|
||||
transforms = [perturb_speed, perturb_volume, perturb_tempo]
|
||||
else:
|
||||
logging.info("Augmentation is disabled")
|
||||
transforms = []
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = ConsistencyRegularizationSpeechRecognitionDataset(
|
||||
@ -254,12 +275,6 @@ class LibriSpeechAsrDataModuleWithParallelAug:
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
|
@ -855,19 +855,38 @@ def compute_loss(
|
||||
disables autograd.
|
||||
"""
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
|
||||
feature_len_seq = [batch["supervisions"]["num_frames"]]
|
||||
text_seq = list(batch["supervisions"]["text"])
|
||||
feature_seq = torch.nn.utils.rnn.unpad_sequence(
|
||||
batch["inputs"],
|
||||
batch["supervisions"]["num_frames"],
|
||||
batch_first=True,
|
||||
)
|
||||
|
||||
if "aug" in batch:
|
||||
for aug in batch["aug"]:
|
||||
feature_len_seq.append(aug["supervisions"]["num_frames"])
|
||||
text_seq.extend(aug["supervisions"]["text"])
|
||||
|
||||
feature_seq.extend(
|
||||
torch.nn.utils.rnn.unpad_sequence(
|
||||
aug["inputs"],
|
||||
aug["supervisions"]["num_frames"],
|
||||
batch_first=True,
|
||||
)
|
||||
)
|
||||
|
||||
feature_lens = torch.cat(feature_len_seq).to(device)
|
||||
feature = torch.nn.utils.rnn.pad_sequence(feature_seq, batch_first=True).to(device)
|
||||
|
||||
# at entry, feature is (N, T, C)
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
batch_idx_train = params.batch_idx_train
|
||||
warm_step = params.warm_step
|
||||
|
||||
texts = batch["supervisions"]["text"]
|
||||
y = sp.encode(texts, out_type=int)
|
||||
y = sp.encode(text_seq, out_type=int)
|
||||
y = k2.RaggedTensor(y)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
@ -1029,6 +1048,9 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
if "aug" in batch:
|
||||
batch_size *= len(batch["aug"]) + 1
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(
|
||||
enabled=params.use_autocast, dtype=params.dtype
|
||||
@ -1360,7 +1382,7 @@ def run(rank, world_size, args):
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
|
||||
if not params.print_diagnostics:
|
||||
if False and not params.print_diagnostics:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=train_dl,
|
||||
@ -1443,12 +1465,32 @@ def display_and_save_batch(
|
||||
logging.info(f"Saving batch to {filename}")
|
||||
torch.save(batch, filename)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
features = batch["inputs"]
|
||||
feature_len_seq = [batch["supervisions"]["num_frames"]]
|
||||
text_seq = list(batch["supervisions"]["text"])
|
||||
feature_seq = torch.nn.utils.rnn.unpad_sequence(
|
||||
batch["inputs"],
|
||||
batch["supervisions"]["num_frames"],
|
||||
batch_first=True,
|
||||
)
|
||||
|
||||
if "aug" in batch:
|
||||
for aug in batch["aug"]:
|
||||
feature_len_seq.append(aug["supervisions"]["num_frames"])
|
||||
text_seq.extend(aug["supervisions"]["text"])
|
||||
|
||||
feature_seq.extend(
|
||||
torch.nn.utils.rnn.unpad_sequence(
|
||||
aug["inputs"],
|
||||
aug["supervisions"]["num_frames"],
|
||||
batch_first=True,
|
||||
)
|
||||
)
|
||||
|
||||
features = torch.nn.utils.rnn.pad_sequence(feature_seq, batch_first=True)
|
||||
|
||||
logging.info(f"features shape: {features.shape}")
|
||||
|
||||
y = sp.encode(supervisions["text"], out_type=int)
|
||||
y = sp.encode(text_seq, out_type=int)
|
||||
num_tokens = sum(len(i) for i in y)
|
||||
logging.info(f"num tokens: {num_tokens}")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user