pruned-transducer-stateless5 recipe for aishell4
parent b0e565a253
commit 4215ec434a
@@ -29,7 +29,7 @@ import os
 from pathlib import Path
 
 import torch
-from lhotse import CutSet, Fbank, FbankConfig, ChunkedLilcomHdf5Writer
+from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
 from lhotse.recipes.utils import read_manifests_if_cached
 
 from icefall.utils import get_executor
@@ -54,7 +54,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80):
         "test",
     )
     manifests = read_manifests_if_cached(
         dataset_parts=dataset_parts,
         output_dir=src_dir,
         prefix="aishell4",
         suffix="jsonl.gz",
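The two hunks above only touch imports and manifest reading in the AIShell-4 fbank script. For orientation, a minimal sketch of how these pieces usually fit together in an icefall prepare step; the directory paths, dataset parts, and num_jobs value below are illustrative assumptions, not taken from this commit:

```python
from pathlib import Path

from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor


def compute_fbank_aishell4(num_mel_bins: int = 80):
    # Assumed layout: manifests under data/manifests, features under data/fbank.
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")

    manifests = read_manifests_if_cached(
        dataset_parts=("test",),
        output_dir=src_dir,
        prefix="aishell4",
        suffix="jsonl.gz",
    )
    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:
        for part, m in manifests.items():
            cuts = CutSet.from_manifests(
                recordings=m["recordings"],
                supervisions=m["supervisions"],
            )
            # Compute fbank features and store them with the chunked HDF5 writer.
            cuts = cuts.compute_and_store_features(
                extractor=extractor,
                storage_path=f"{output_dir}/aishell4_feats_{part}",
                num_jobs=15,  # illustrative value
                executor=ex,
                storage_type=ChunkedLilcomHdf5Writer,
            )
            cuts.to_file(output_dir / f"aishell4_cuts_{part}.jsonl.gz")
```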
@@ -64,7 +64,6 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
 import k2
-import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import Aishell4AsrDataModule
@@ -83,6 +82,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -152,7 +152,7 @@ def get_parser():
         "lexicon.txt"
         """,
     )
 
     parser.add_argument(
         "--decoding-method",
         type=str,
@@ -628,7 +628,7 @@ def main():
         shuffle_shards=True,
     )
 
-    test_dl = aidatatang_200zh.test_dataloaders(cuts_test_webdataset)
+    test_dl = aishell4.test_dataloaders(cuts_test_webdataset)
 
     test_sets = ["test"]
     test_dl = [test_dl]
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
 LRSchedulerType = Union[
@@ -752,29 +753,25 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
-                loss, loss_info = compute_loss(
-                    params=params,
-                    model=model,
-                    graph_compiler=graph_compiler,
-                    batch=batch,
-                    is_training=True,
-                    warmup=(params.batch_idx_train / params.model_warm_step),
-                )
-            # summary stats
-            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
-
-            # NOTE: We use reduction==sum and loss is computed over utterances
-            # in the batch and there is no normalization to it so far.
-            scaler.scale(loss).backward()
-            scheduler.step_batch(params.batch_idx_train)
-            scaler.step(optimizer)
-            scaler.update()
-            optimizer.zero_grad()
-        except:  # noqa
-            display_and_save_batch(batch, params=params, sp=sp)
-            raise
+        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            loss, loss_info = compute_loss(
+                params=params,
+                model=model,
+                graph_compiler=graph_compiler,
+                batch=batch,
+                is_training=True,
+                warmup=(params.batch_idx_train / params.model_warm_step),
+            )
+        # summary stats
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+        # NOTE: We use reduction==sum and loss is computed over utterances
+        # in the batch and there is no normalization to it so far.
+        scaler.scale(loss).backward()
+        scheduler.step_batch(params.batch_idx_train)
+        scaler.step(optimizer)
+        scaler.update()
+        optimizer.zero_grad()
 
         if params.print_diagnostics and batch_idx == 5:
             return
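The hunk above removes the try/except wrapper that called display_and_save_batch on failure (that helper is deleted later in this diff); the step itself keeps the usual autocast/GradScaler pattern. A minimal sketch of that pattern, with illustrative names and a simplified compute_loss signature rather than the recipe's exact one:

```python
from typing import Callable, Tuple

import torch


def train_step(
    batch: dict,
    compute_loss: Callable[..., Tuple[torch.Tensor, dict]],
    optimizer: torch.optim.Optimizer,
    scheduler,  # an icefall Eden-style scheduler exposing step_batch()
    scaler: torch.cuda.amp.GradScaler,
    batch_idx_train: int,
    use_fp16: bool,
) -> None:
    # Forward pass under autocast so the model runs in fp16 when enabled.
    with torch.cuda.amp.autocast(enabled=use_fp16):
        loss, _loss_info = compute_loss(batch=batch, is_training=True)

    scaler.scale(loss).backward()          # scaled backward pass
    scheduler.step_batch(batch_idx_train)  # per-batch learning-rate update
    scaler.step(optimizer)                 # unscales grads, then optimizer.step()
    scaler.update()                        # adjust the loss scale for the next step
    optimizer.zero_grad()
```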
@@ -897,7 +894,7 @@ def run(rank, world_size, args):
         lexicon=lexicon,
         device=device,
     )
 
     params.blank_id = lexicon.token_table["<blk>"]
     params.vocab_size = max(lexicon.tokens) + 1
 
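The lexicon lines above are context around the CharCtcTrainingGraphCompiler construction; together with the Lexicon import and the graph_compiler arguments added elsewhere in this diff, they imply a char-lexicon setup replacing the earlier sentencepiece one. A minimal sketch of that setup, assuming the import path used by other icefall char-based recipes and an illustrative lang dir:

```python
from pathlib import Path

import torch

from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.lexicon import Lexicon

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

# Assumed char lang dir produced by the recipe's prepare stage.
lexicon = Lexicon(Path("data/lang_char"))
graph_compiler = CharCtcTrainingGraphCompiler(
    lexicon=lexicon,
    device=device,
)

# Token ids now come from the lexicon's token table instead of a BPE model.
blank_id = lexicon.token_table["<blk>"]
vocab_size = max(lexicon.tokens) + 1

# Transcripts are mapped to token-id sequences via the graph compiler, e.g.
# token_ids = graph_compiler.texts_to_ids(["some transcript"])
```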
@@ -1039,43 +1036,11 @@ def run(rank, world_size, args):
         cleanup_dist()
 
 
-def display_and_save_batch(
-    batch: dict,
-    params: AttributeDict,
-    sp: spm.SentencePieceProcessor,
-) -> None:
-    """Display the batch statistics and save the batch into disk.
-
-    Args:
-      batch:
-        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
-        for the content in it.
-      params:
-        Parameters for training. See :func:`get_params`.
-      sp:
-        The BPE model.
-    """
-    from lhotse.utils import uuid4
-
-    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
-    logging.info(f"Saving batch to {filename}")
-    torch.save(batch, filename)
-
-    supervisions = batch["supervisions"]
-    features = batch["inputs"]
-
-    logging.info(f"features shape: {features.shape}")
-
-    y = sp.encode(supervisions["text"], out_type=int)
-    num_tokens = sum(len(i) for i in y)
-    logging.info(f"num tokens: {num_tokens}")
-
-
 def scan_pessimistic_batches_for_oom(
     model: Union[nn.Module, DDP],
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
-    sp: spm.SentencePieceProcessor,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
@@ -1094,7 +1059,7 @@ def scan_pessimistic_batches_for_oom(
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
-                    sp=sp,
+                    graph_compiler=graph_compiler,
                     batch=batch,
                     is_training=True,
                     warmup=0.0,
@@ -1111,7 +1076,6 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-                display_and_save_batch(batch, params=params, sp=sp)
             raise
 
 