mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)

First working example

parent 516696f3e4
commit 0f88a3a6c3
@@ -28,7 +28,6 @@ import torch
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
 from lhotse.cut import Cut
 from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
-    CutConcatenate,
     CutMix,
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
@@ -57,21 +56,33 @@ class _SeedWorkers:
         fix_random_seed(self.seed + worker_id)
 
 
+"""
+We use c.features = None below to suppress the following warnings
+
+2025-05-29 16:49:55,253 WARNING [data.py:801] Attempting to perturb speed on a
+DataCut that references pre-computed features. The feature manifest will be
+detached, as we do not support feature-domain speed perturbation.
+"""
+
+
 def perturb_speed(c: Cut):
-    factor = random.uniform(0.9, 1.1)
-    print("perturb_speed factor", factor)
+    factor = random.choice([0.9, 1.1])
+    c.features = None
 
     return lhotse.MonoCut.perturb_speed(c, factor)
 
 
 def perturb_volume(c: Cut):
-    factor = random.uniform(0.9, 1.1)
-    print("perturb_volume factor", factor)
+    factor = random.choice([0.9, 1.1])
+    c.features = None
 
     return lhotse.MonoCut.perturb_volume(c, factor)
 
 
 def perturb_tempo(c: Cut):
-    factor = random.uniform(0.9, 1.1)
-    print("perturb_tempo factor", factor)
+    factor = random.choice([0.9, 1.1])
+    c.features = None
     return lhotse.MonoCut.perturb_tempo(c, factor)
 
 
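For context, each of these helpers is a plain per-cut callable, so it can also be exercised on its own before being wired into the data module. A minimal sketch, assuming lhotse's CutSet.map and a typical icefall fbank manifest path (both are assumptions for illustration, not part of this commit, and the data module itself passes the callables to the dataset instead):

    import random

    import lhotse
    from lhotse import load_manifest_lazy
    from lhotse.cut import Cut


    def perturb_speed(c: Cut):
        # Same logic as in the diff above: pick a discrete factor and drop the
        # pre-computed feature manifest so lhotse does not warn that
        # feature-domain speed perturbation is unsupported.
        factor = random.choice([0.9, 1.1])
        c.features = None
        return lhotse.MonoCut.perturb_speed(c, factor)


    # Hypothetical manifest path, for illustration only.
    cuts = load_manifest_lazy("data/fbank/librispeech_cuts_train-clean-100.jsonl.gz")
    augmented = cuts.map(perturb_speed)  # applies the transform cut by cut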
@@ -86,7 +97,6 @@ class LibriSpeechAsrDataModuleWithParallelAug:
     experiments, e.g.:
     - dynamic batch size,
     - bucketing samplers,
-    - cut concatenation,
     - augmentation,
     - on-the-fly feature extraction
 
@@ -112,6 +122,12 @@ class LibriSpeechAsrDataModuleWithParallelAug:
             help="""Used only when --mini-libri is False.When enabled,
             use 960h LibriSpeech. Otherwise, use 100h subset.""",
         )
+        group.add_argument(
+            "--enable-augmentation",
+            type=str2bool,
+            default=True,
+            help="True to enable augmentation for training set",
+        )
         group.add_argument(
             "--mini-libri",
             type=str2bool,
@@ -204,7 +220,12 @@ class LibriSpeechAsrDataModuleWithParallelAug:
           sampler_state_dict:
             The state dict for the training sampler.
         """
-        transforms = [perturb_speed, perturb_volume, perturb_tempo]
+        if self.args.enable_augmentation:
+            logging.info("Augmentation is enabled")
+            transforms = [perturb_speed, perturb_volume, perturb_tempo]
+        else:
+            logging.info("Augmentation is disabled")
+            transforms = []
 
         logging.info("About to create train dataset")
         train = ConsistencyRegularizationSpeechRecognitionDataset(
@@ -254,12 +275,6 @@ class LibriSpeechAsrDataModuleWithParallelAug:
 
     def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
         transforms = []
-        if self.args.concatenate_cuts:
-            transforms = [
-                CutConcatenate(
-                    duration_factor=self.args.duration_factor, gap=self.args.gap
-                )
-            ] + transforms
 
         logging.info("About to create dev dataset")
         if self.args.on_the_fly_feats:
@@ -855,19 +855,38 @@ def compute_loss(
         disables autograd.
     """
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
-    feature = batch["inputs"]
+    feature_len_seq = [batch["supervisions"]["num_frames"]]
+    text_seq = list(batch["supervisions"]["text"])
+    feature_seq = torch.nn.utils.rnn.unpad_sequence(
+        batch["inputs"],
+        batch["supervisions"]["num_frames"],
+        batch_first=True,
+    )
+
+    if "aug" in batch:
+        for aug in batch["aug"]:
+            feature_len_seq.append(aug["supervisions"]["num_frames"])
+            text_seq.extend(aug["supervisions"]["text"])
+
+            feature_seq.extend(
+                torch.nn.utils.rnn.unpad_sequence(
+                    aug["inputs"],
+                    aug["supervisions"]["num_frames"],
+                    batch_first=True,
+                )
+            )
+
+    feature_lens = torch.cat(feature_len_seq).to(device)
+    feature = torch.nn.utils.rnn.pad_sequence(feature_seq, batch_first=True).to(device)
+
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
-    feature = feature.to(device)
-
-    supervisions = batch["supervisions"]
-    feature_lens = supervisions["num_frames"].to(device)
 
     batch_idx_train = params.batch_idx_train
     warm_step = params.warm_step
 
-    texts = batch["supervisions"]["text"]
-    y = sp.encode(texts, out_type=int)
+    y = sp.encode(text_seq, out_type=int)
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
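The block above unpads every view to its true number of frames, pools the original and augmented cuts into one list, and re-pads them into a single (N_total, T, C) batch before encoding the pooled texts. A self-contained sketch of that round trip with synthetic tensors (shapes and values are made up for illustration; it assumes a PyTorch version that provides torch.nn.utils.rnn.unpad_sequence):

    import torch
    from torch.nn.utils.rnn import pad_sequence, unpad_sequence

    # Synthetic stand-ins for batch["inputs"] / batch["aug"][i]["inputs"].
    batch = {
        "inputs": torch.randn(2, 10, 80),  # (N, T, C), padded fbank features
        "supervisions": {"num_frames": torch.tensor([10, 7])},
    }
    aug_view = {
        "inputs": torch.randn(2, 12, 80),
        "supervisions": {"num_frames": torch.tensor([12, 9])},
    }

    feature_len_seq = [batch["supervisions"]["num_frames"]]
    feature_seq = unpad_sequence(
        batch["inputs"], batch["supervisions"]["num_frames"], batch_first=True
    )

    # Merge one augmented view, mirroring the loop over batch["aug"] above.
    feature_len_seq.append(aug_view["supervisions"]["num_frames"])
    feature_seq.extend(
        unpad_sequence(
            aug_view["inputs"], aug_view["supervisions"]["num_frames"], batch_first=True
        )
    )

    feature_lens = torch.cat(feature_len_seq)              # shape (4,)
    feature = pad_sequence(feature_seq, batch_first=True)  # shape (4, 12, 80)
    assert feature.shape == (4, 12, 80)
    assert feature_lens.tolist() == [10, 7, 12, 9]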
@@ -1029,6 +1048,9 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
+        if "aug" in batch:
+            batch_size *= len(batch["aug"]) + 1
+
         try:
             with torch.cuda.amp.autocast(
                 enabled=params.use_autocast, dtype=params.dtype
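With parallel views attached, the logged batch size counts every copy of each utterance. A quick worked example with illustrative numbers (not taken from a real run):

    batch_size = 16                  # len(batch["supervisions"]["text"]): original cuts only
    num_aug_views = 2                # len(batch["aug"]): hypothetical number of augmented views
    batch_size *= num_aug_views + 1  # 16 * (2 + 1) == 48 utterances in the combined minibatch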
@@ -1360,7 +1382,7 @@ def run(rank, world_size, args):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
+    if False and not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
@@ -1443,12 +1465,32 @@ def display_and_save_batch(
     logging.info(f"Saving batch to {filename}")
     torch.save(batch, filename)
 
-    supervisions = batch["supervisions"]
-    features = batch["inputs"]
+    feature_len_seq = [batch["supervisions"]["num_frames"]]
+    text_seq = list(batch["supervisions"]["text"])
+    feature_seq = torch.nn.utils.rnn.unpad_sequence(
+        batch["inputs"],
+        batch["supervisions"]["num_frames"],
+        batch_first=True,
+    )
+
+    if "aug" in batch:
+        for aug in batch["aug"]:
+            feature_len_seq.append(aug["supervisions"]["num_frames"])
+            text_seq.extend(aug["supervisions"]["text"])
+
+            feature_seq.extend(
+                torch.nn.utils.rnn.unpad_sequence(
+                    aug["inputs"],
+                    aug["supervisions"]["num_frames"],
+                    batch_first=True,
+                )
+            )
+
+    features = torch.nn.utils.rnn.pad_sequence(feature_seq, batch_first=True)
 
     logging.info(f"features shape: {features.shape}")
 
-    y = sp.encode(supervisions["text"], out_type=int)
+    y = sp.encode(text_seq, out_type=int)
     num_tokens = sum(len(i) for i in y)
     logging.info(f"num tokens: {num_tokens}")
 