pruned-transducer-stateless5 recipe for aishell4

Author: luomingshuang
Date:   2022-06-05 19:53:37 +08:00
Parent: b0e565a253
Commit: 4215ec434a
3 changed files with 27 additions and 63 deletions

File 1 of 3: the AIShell-4 fbank feature-extraction script (the file path is not preserved in this view).

@@ -29,7 +29,7 @@ import os
 from pathlib import Path

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, ChunkedLilcomHdf5Writer
+from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor
@@ -54,7 +54,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
         "test",
     )
     manifests = read_manifests_if_cached(
-        dataset_parts=dataset_parts,
+        dataset_parts=dataset_parts,
         output_dir=src_dir,
         prefix="aishell4",
         suffix="jsonl.gz",

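For orientation, this script implements lhotse's compute-and-store pipeline: build an Fbank extractor and write the features with the ChunkedLilcomHdf5Writer whose import is being re-sorted above. A minimal sketch, assuming a CutSet named cuts already exists; the storage path and num_jobs value are illustrative, not taken from the script:

from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig

def extract_fbank(cuts: CutSet, output_dir: str, num_mel_bins: int = 80) -> CutSet:
    # 80-dim log-Mel filterbank, the front end this recipe uses.
    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
    # Compute features and store them chunked and lilcom-compressed in HDF5.
    return cuts.compute_and_store_features(
        extractor=extractor,
        storage_path=f"{output_dir}/feats_aishell4",  # illustrative path
        num_jobs=4,  # the real script parallelizes via icefall.utils.get_executor
        storage_type=ChunkedLilcomHdf5Writer,
    )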
File 2 of 3: the recipe's decoding script (the file path is not preserved in this view).

@@ -64,7 +64,6 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple

 import k2
-import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import Aishell4AsrDataModule
@@ -83,6 +82,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -152,7 +152,7 @@ def get_parser():
         "lexicon.txt"
         """,
     )

     parser.add_argument(
         "--decoding-method",
         type=str,
@@ -628,7 +628,7 @@ def main():
         shuffle_shards=True,
     )

-    test_dl = aidatatang_200zh.test_dataloaders(cuts_test_webdataset)
+    test_dl = aishell4.test_dataloaders(cuts_test_webdataset)

     test_sets = ["test"]
     test_dl = [test_dl]

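The functional fix above replaces a leftover variable name from the aidatatang_200zh recipe this file was adapted from. A sketch of the corrected wiring, assuming the data module exposes a test-cut accessor; the test_cuts method name and the args object are assumptions, not confirmed by this diff:

from asr_datamodule import Aishell4AsrDataModule

aishell4 = Aishell4AsrDataModule(args)          # args: the parsed CLI arguments
cuts_test_webdataset = aishell4.test_cuts()     # assumed accessor name
test_dl = aishell4.test_dataloaders(cuts_test_webdataset)

test_sets = ["test"]
test_dl = [test_dl]  # decoding then iterates over zip(test_sets, test_dl)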
File 3 of 3: the recipe's training script (the file path is not preserved in this view).

@@ -81,6 +81,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

 LRSchedulerType = Union[
@@ -752,29 +753,25 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
-                loss, loss_info = compute_loss(
-                    params=params,
-                    model=model,
-                    graph_compiler=graph_compiler,
-                    batch=batch,
-                    is_training=True,
-                    warmup=(params.batch_idx_train / params.model_warm_step),
-                )
-            # summary stats
-            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
-
-            # NOTE: We use reduction==sum and loss is computed over utterances
-            # in the batch and there is no normalization to it so far.
-            scaler.scale(loss).backward()
-            scheduler.step_batch(params.batch_idx_train)
-            scaler.step(optimizer)
-            scaler.update()
-            optimizer.zero_grad()
-        except:  # noqa
-            display_and_save_batch(batch, params=params, sp=sp)
-            raise
+        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            loss, loss_info = compute_loss(
+                params=params,
+                model=model,
+                graph_compiler=graph_compiler,
+                batch=batch,
+                is_training=True,
+                warmup=(params.batch_idx_train / params.model_warm_step),
+            )
+        # summary stats
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+        # NOTE: We use reduction==sum and loss is computed over utterances
+        # in the batch and there is no normalization to it so far.
+        scaler.scale(loss).backward()
+        scheduler.step_batch(params.batch_idx_train)
+        scaler.step(optimizer)
+        scaler.update()
+        optimizer.zero_grad()

         if params.print_diagnostics and batch_idx == 5:
             return
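The de-indented block above is the standard PyTorch mixed-precision step: forward under autocast, backward on a scaled loss, then unscale-and-step through the GradScaler. A self-contained sketch of the same pattern, with a toy model standing in for the transducer; everything below is illustrative:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"  # autocast/GradScaler are no-ops on CPU here

model = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

x = torch.randn(4, 10, device=device)
y = torch.randn(4, 1, device=device)

with torch.cuda.amp.autocast(enabled=use_amp):
    # The recipe computes a pruned transducer loss with reduction="sum";
    # MSE is only a stand-in to keep the sketch runnable.
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(optimizer)         # unscales gradients, then optimizer.step()
scaler.update()                # adapts the loss scale for the next iteration
optimizer.zero_grad()

The recipe additionally calls scheduler.step_batch(...) between backward and scaler.step, because its learning-rate scheduler is stepped per batch rather than per epoch.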
@@ -897,7 +894,7 @@ def run(rank, world_size, args):
         lexicon=lexicon,
         device=device,
     )

     params.blank_id = lexicon.token_table["<blk>"]
     params.vocab_size = max(lexicon.tokens) + 1
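For reference, the two lines above derive the transducer's blank id and vocabulary size from the character lexicon, which is why the Lexicon import is added in this commit. A sketch, with an illustrative lang-dir path:

from pathlib import Path

from icefall.lexicon import Lexicon

lexicon = Lexicon(Path("data/lang_char"))  # illustrative lang dir
blank_id = lexicon.token_table["<blk>"]    # id of the transducer blank symbol
vocab_size = max(lexicon.tokens) + 1       # largest token id plus one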
@@ -1039,43 +1036,11 @@ def run(rank, world_size, args):
     cleanup_dist()


-def display_and_save_batch(
-    batch: dict,
-    params: AttributeDict,
-    sp: spm.SentencePieceProcessor,
-) -> None:
-    """Display the batch statistics and save the batch into disk.
-
-    Args:
-      batch:
-        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
-        for the content in it.
-      params:
-        Parameters for training. See :func:`get_params`.
-      sp:
-        The BPE model.
-    """
-    from lhotse.utils import uuid4
-
-    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
-    logging.info(f"Saving batch to {filename}")
-    torch.save(batch, filename)
-
-    supervisions = batch["supervisions"]
-    features = batch["inputs"]
-    logging.info(f"features shape: {features.shape}")
-
-    y = sp.encode(supervisions["text"], out_type=int)
-    num_tokens = sum(len(i) for i in y)
-    logging.info(f"num tokens: {num_tokens}")
-
-
 def scan_pessimistic_batches_for_oom(
     model: Union[nn.Module, DDP],
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
-    sp: spm.SentencePieceProcessor,
     graph_compiler: CharCtcTrainingGraphCompiler,
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
@@ -1094,7 +1059,7 @@ def scan_pessimistic_batches_for_oom(
             loss, _ = compute_loss(
                 params=params,
                 model=model,
-                sp=sp,
+                graph_compiler=graph_compiler,
                 batch=batch,
                 is_training=True,
                 warmup=0.0,
@@ -1111,7 +1076,6 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-                display_and_save_batch(batch, params=params, sp=sp)
                 raise
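For context, the function touched above runs before training proper: lhotse's find_pessimistic_batches picks the most memory-demanding batches under each sampling criterion, and one forward/backward pass on each surfaces out-of-memory failures immediately rather than hours into an epoch. A condensed sketch of that pattern as it stands after this commit; it assumes the recipe's own compute_loss and training objects, so it runs inside the recipe's train.py rather than standalone:

import logging

import torch
from lhotse.dataset import find_pessimistic_batches

def scan_pessimistic_batches_for_oom(model, train_dl, optimizer, graph_compiler, params):
    logging.info("Sanity check: see if any batch in epoch 0 would cause OOM.")
    # The most memory-demanding batches, one per sampling criterion.
    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]  # materialize the worst-case batch
        try:
            with torch.cuda.amp.autocast(enabled=params.use_fp16):
                # compute_loss is the recipe's loss function; the argument
                # list matches the hunk above after this commit.
                loss, _ = compute_loss(
                    params=params,
                    model=model,
                    graph_compiler=graph_compiler,
                    batch=batch,
                    is_training=True,
                    warmup=0.0,
                )
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                logging.error(
                    "Your GPU ran out of memory with the current "
                    f"max_duration setting. Failing criterion: {criterion} "
                    f"(={crit_values[criterion]})."
                )
            raise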