mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
Incorporate changes from master into pruned_transducer_stateless2.
This commit is contained in:
parent
05b5e78d8f
commit
ccbf8ba086
@ -42,6 +42,17 @@ Usage:
|
||||
--max-duration 100 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search
|
||||
./pruned_transducer_stateless/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless/exp \
|
||||
--max-duration 1500 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
"""
|
||||
|
||||
|
||||
@ -49,16 +60,26 @@ import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import beam_search, greedy_search, modified_beam_search
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search,
|
||||
greedy_search,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
@ -88,6 +109,17 @@ def get_parser():
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg-last-n",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch and --avg are ignored and it
|
||||
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
|
||||
where xxx is the number of processed batches while
|
||||
saving that checkpoint.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
@ -110,6 +142,7 @@ def get_parser():
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
@ -117,8 +150,35 @@ def get_parser():
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An interger indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --decoding-method is
|
||||
beam_search or modified_beam_search""",
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -144,6 +204,7 @@ def decode_one_batch(
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
@ -166,6 +227,9 @@ def decode_one_batch(
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
@ -184,36 +248,62 @@ def decode_one_batch(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
hyps = []
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp = modified_beam_search(
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
return {
|
||||
(
|
||||
f"beam_{params.beam}_"
|
||||
f"max_contexts_{params.max_contexts}_"
|
||||
f"max_states_{params.max_states}"
|
||||
): hyps
|
||||
}
|
||||
else:
|
||||
return {f"beam_{params.beam_size}": hyps}
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
@ -221,6 +311,7 @@ def decode_dataset(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
@ -233,6 +324,9 @@ def decode_dataset(
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
@ -260,6 +354,7 @@ def decode_dataset(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
@ -340,12 +435,17 @@ def main():
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
if "beam_search" in params.decoding_method:
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
@ -372,7 +472,12 @@ def main():
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if params.avg == 1:
|
||||
if params.avg_last_n > 0:
|
||||
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
@ -388,6 +493,11 @@ def main():
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
@ -408,6 +518,7 @@ def main():
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
|
@ -64,6 +64,7 @@ class Decoder(nn.Module):
|
||||
|
||||
assert context_size >= 1, context_size
|
||||
self.context_size = context_size
|
||||
self.vocab_size = vocab_size
|
||||
if context_size > 1:
|
||||
self.conv = ScaledConv1d(
|
||||
in_channels=embedding_dim,
|
||||
|
@ -36,7 +36,7 @@ import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
@ -48,6 +48,7 @@ from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from torch import Tensor
|
||||
@ -55,8 +56,9 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from transformer import Noam
|
||||
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall import diagnostics
|
||||
@ -112,6 +114,15 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--start-batch",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --start-epoch is ignored and
|
||||
it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
@ -192,6 +203,30 @@ def get_parser():
|
||||
help="Accumulate stats on activations, print them and exit.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--save-every-n",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="""Save checkpoint after processing this number of batches"
|
||||
periodically. We save checkpoint to exp-dir/ whenever
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
||||
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
||||
end of each epoch where `xxx` is the epoch number counting from 0.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--keep-last-k",
|
||||
type=int,
|
||||
default=20,
|
||||
help="""Only keep this number of checkpoints on disk.
|
||||
For instance, if it is 3, there are only 3 checkpoints
|
||||
in the exp-dir with filenames `checkpoint-xxx.pt`.
|
||||
It does not affect checkpoints with name `epoch-xxx.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -320,15 +355,16 @@ def load_checkpoint_if_available(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||
) -> None:
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Load checkpoint from file.
|
||||
|
||||
If params.start_epoch is positive, it will load the checkpoint from
|
||||
`params.start_epoch - 1`. Otherwise, this function does nothing.
|
||||
If params.start_batch is positive, it will load the checkpoint from
|
||||
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
|
||||
params.start_epoch is positive, it will load the checkpoint from
|
||||
`params.start_epoch - 1`.
|
||||
|
||||
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
|
||||
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
|
||||
Apart from loading state dict for `model` and `optimizer` it also updates
|
||||
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
|
||||
and `best_valid_loss` in `params`.
|
||||
|
||||
Args:
|
||||
@ -338,20 +374,22 @@ def load_checkpoint_if_available(
|
||||
The training model.
|
||||
optimizer:
|
||||
The optimizer that we are using.
|
||||
scheduler:
|
||||
The learning rate scheduler we are using.
|
||||
Returns:
|
||||
Return None.
|
||||
Return a dict containing previously saved training info.
|
||||
"""
|
||||
if params.start_epoch <= 0:
|
||||
return
|
||||
if params.start_batch > 0:
|
||||
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
|
||||
elif params.start_epoch > 0:
|
||||
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||
else:
|
||||
return None
|
||||
|
||||
assert filename.is_file(), f"{filename} does not exist!"
|
||||
|
||||
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||
saved_params = load_checkpoint(
|
||||
filename,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
keys = [
|
||||
@ -360,10 +398,13 @@ def load_checkpoint_if_available(
|
||||
"batch_idx_train",
|
||||
"best_train_loss",
|
||||
"best_valid_loss",
|
||||
"cur_batch_idx",
|
||||
]
|
||||
for k in keys:
|
||||
params[k] = saved_params[k]
|
||||
|
||||
params["start_epoch"] = saved_params["cur_epoch"]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
@ -371,7 +412,7 @@ def save_checkpoint(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -381,6 +422,10 @@ def save_checkpoint(
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The training model.
|
||||
optimizer:
|
||||
The optimizer used in the training.
|
||||
sampler:
|
||||
The sampler for the training dataset.
|
||||
"""
|
||||
if rank != 0:
|
||||
return
|
||||
@ -390,7 +435,7 @@ def save_checkpoint(
|
||||
model=model,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
sampler=sampler,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
@ -509,6 +554,7 @@ def train_one_epoch(
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Train the model for one epoch.
|
||||
|
||||
@ -531,12 +577,21 @@ def train_one_epoch(
|
||||
Writer to write log messages to tensorboard.
|
||||
world_size:
|
||||
Number of nodes in DDP training. If it is 1, DDP is disabled.
|
||||
rank:
|
||||
The rank of the node in DDP training. If no DDP is used, it should
|
||||
be set to 0.
|
||||
"""
|
||||
model.train()
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
cur_batch_idx = params.get("cur_batch_idx", 0)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx < cur_batch_idx:
|
||||
continue
|
||||
cur_batch_idx = batch_idx
|
||||
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -560,6 +615,27 @@ def train_one_epoch(
|
||||
if params.print_diagnostics and batch_idx == 5:
|
||||
return
|
||||
|
||||
if (
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
params.cur_batch_idx = batch_idx
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
model=model,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
sampler=train_dl.sampler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, "
|
||||
@ -688,7 +764,14 @@ def run(rank, world_size, args):
|
||||
|
||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||
|
||||
train_dl = librispeech.train_dataloaders(train_cuts)
|
||||
if checkpoints and "sampler" in checkpoints:
|
||||
sampler_state_dict = checkpoints["sampler"]
|
||||
else:
|
||||
sampler_state_dict = None
|
||||
|
||||
train_dl = librispeech.train_dataloaders(
|
||||
train_cuts, sampler_state_dict=sampler_state_dict
|
||||
)
|
||||
|
||||
valid_cuts = librispeech.dev_clean_cuts()
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
@ -728,6 +811,7 @@ def run(rank, world_size, args):
|
||||
valid_dl=valid_dl,
|
||||
tb_writer=tb_writer,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if params.print_diagnostics:
|
||||
@ -738,6 +822,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
sampler=train_dl.sampler,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
|
@ -135,8 +135,13 @@ def get_diagnostics_for_dim(
|
||||
return ""
|
||||
count = sum(counts)
|
||||
stats = stats / count
|
||||
stats, _ = torch.symeig(stats)
|
||||
stats = stats.abs().sqrt()
|
||||
try:
|
||||
eigs, _ = torch.symeig(stats)
|
||||
stats = eigs.abs().sqrt()
|
||||
except:
|
||||
print("Error getting eigenvalues, trying another method")
|
||||
eigs, _ = torch.eigs(stats)
|
||||
stats = eigs.abs().sqrt()
|
||||
# sqrt so it reflects data magnitude, like stddev- not variance
|
||||
elif sizes_same:
|
||||
stats = torch.stack(stats).sum(dim=0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user