pruned-transducer-stateless5 recipe for aishell4

This commit is contained in:
luomingshuang 2022-06-05 19:53:37 +08:00
parent b0e565a253
commit 4215ec434a
3 changed files with 27 additions and 63 deletions

View File

@ -29,7 +29,7 @@ import os
from pathlib import Path from pathlib import Path
import torch import torch
from lhotse import CutSet, Fbank, FbankConfig, ChunkedLilcomHdf5Writer from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
from lhotse.recipes.utils import read_manifests_if_cached from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor from icefall.utils import get_executor
@ -54,7 +54,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
"test", "test",
) )
manifests = read_manifests_if_cached( manifests = read_manifests_if_cached(
dataset_parts=dataset_parts, dataset_parts=dataset_parts,
output_dir=src_dir, output_dir=src_dir,
prefix="aishell4", prefix="aishell4",
suffix="jsonl.gz", suffix="jsonl.gz",

View File

@ -64,7 +64,6 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import k2 import k2
import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import Aishell4AsrDataModule from asr_datamodule import Aishell4AsrDataModule
@ -83,6 +82,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
setup_logger, setup_logger,
@ -152,7 +152,7 @@ def get_parser():
"lexicon.txt" "lexicon.txt"
""", """,
) )
parser.add_argument( parser.add_argument(
"--decoding-method", "--decoding-method",
type=str, type=str,
@ -628,7 +628,7 @@ def main():
shuffle_shards=True, shuffle_shards=True,
) )
test_dl = aidatatang_200zh.test_dataloaders(cuts_test_webdataset) test_dl = aishell4.test_dataloaders(cuts_test_webdataset)
test_sets = ["test"] test_sets = ["test"]
test_dl = [test_dl] test_dl = [test_dl]

View File

@ -81,6 +81,7 @@ from icefall.checkpoint import (
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[ LRSchedulerType = Union[
@ -752,29 +753,25 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
try: with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss(
loss, loss_info = compute_loss( params=params,
params=params, model=model,
model=model, graph_compiler=graph_compiler,
graph_compiler=graph_compiler, batch=batch,
batch=batch, is_training=True,
is_training=True, warmup=(params.batch_idx_train / params.model_warm_step),
warmup=(params.batch_idx_train / params.model_warm_step), )
) # summary stats
# summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances # NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far. # in the batch and there is no normalization to it so far.
scaler.scale(loss).backward() scaler.scale(loss).backward()
scheduler.step_batch(params.batch_idx_train) scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer) scaler.step(optimizer)
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
except: # noqa
display_and_save_batch(batch, params=params, sp=sp)
raise
if params.print_diagnostics and batch_idx == 5: if params.print_diagnostics and batch_idx == 5:
return return
@ -897,7 +894,7 @@ def run(rank, world_size, args):
lexicon=lexicon, lexicon=lexicon,
device=device, device=device,
) )
params.blank_id = lexicon.token_table["<blk>"] params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1 params.vocab_size = max(lexicon.tokens) + 1
@ -1039,43 +1036,11 @@ def run(rank, world_size, args):
cleanup_dist() cleanup_dist()
def display_and_save_batch(
batch: dict,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
) -> None:
"""Display the batch statistics and save the batch into disk.
Args:
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
params:
Parameters for training. See :func:`get_params`.
sp:
The BPE model.
"""
from lhotse.utils import uuid4
filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
logging.info(f"Saving batch to {filename}")
torch.save(batch, filename)
supervisions = batch["supervisions"]
features = batch["inputs"]
logging.info(f"features shape: {features.shape}")
y = sp.encode(supervisions["text"], out_type=int)
num_tokens = sum(len(i) for i in y)
logging.info(f"num tokens: {num_tokens}")
def scan_pessimistic_batches_for_oom( def scan_pessimistic_batches_for_oom(
model: Union[nn.Module, DDP], model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader, train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor, graph_compiler: CharCtcTrainingGraphCompiler,
params: AttributeDict, params: AttributeDict,
): ):
from lhotse.dataset import find_pessimistic_batches from lhotse.dataset import find_pessimistic_batches
@ -1094,7 +1059,7 @@ def scan_pessimistic_batches_for_oom(
loss, _ = compute_loss( loss, _ = compute_loss(
params=params, params=params,
model=model, model=model,
sp=sp, graph_compiler=graph_compiler,
batch=batch, batch=batch,
is_training=True, is_training=True,
warmup=0.0, warmup=0.0,
@ -1111,7 +1076,6 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} " f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..." f"(={crit_values[criterion]}) ..."
) )
display_and_save_batch(batch, params=params, sp=sp)
raise raise