Incorporate changes from master into pruned_transducer_stateless2.

Daniel Povey 2022-03-21 16:51:48 +08:00
parent 05b5e78d8f
commit ccbf8ba086
4 changed files with 253 additions and 51 deletions

View File

@@ -42,6 +42,17 @@ Usage:
     --max-duration 100 \
     --decoding-method modified_beam_search \
     --beam-size 4
+
+(4) fast beam search
+./pruned_transducer_stateless/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless/exp \
+    --max-duration 1500 \
+    --decoding-method fast_beam_search \
+    --beam 4 \
+    --max-contexts 4 \
+    --max-states 8
 """
@@ -49,16 +60,26 @@ import argparse
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple

+import k2
 import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import beam_search, greedy_search, modified_beam_search
+from beam_search import (
+    beam_search,
+    fast_beam_search,
+    greedy_search,
+    modified_beam_search,
+)
 from train import get_params, get_transducer_model

-from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -88,6 +109,17 @@ def get_parser():
         "'--epoch'. ",
     )

+    parser.add_argument(
+        "--avg-last-n",
+        type=int,
+        default=0,
+        help="""If positive, --epoch and --avg are ignored and it
+        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
+        where xxx is the number of processed batches while
+        saving that checkpoint.
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
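Below, main() wires this flag up via find_checkpoints() and average_checkpoints(). As a hedged, standalone sketch of the same idea (find_checkpoints is assumed to return the newest checkpoint-xxx.pt files first, which is what the slice in main() relies on):

```python
from icefall.checkpoint import average_checkpoints, find_checkpoints

def load_averaged_last_n(model, exp_dir, avg_last_n, device):
    # Pick the newest `avg_last_n` batch-level checkpoints and average
    # their parameters into `model`, mirroring the new branch in main().
    filenames = find_checkpoints(exp_dir)[:avg_last_n]
    model.to(device)
    model.load_state_dict(average_checkpoints(filenames, device=device))
    return model
```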
@@ -110,6 +142,7 @@ def get_parser():
           - greedy_search
           - beam_search
           - modified_beam_search
+          - fast_beam_search
         """,
     )
@@ -117,8 +150,35 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="""Used only when --decoding-method is
-        beam_search or modified_beam_search""",
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
     )

     parser.add_argument(
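The --beam help text gives the pruning rule `cutoff = max-score - beam`. A toy illustration of that rule (not the actual k2 implementation, which prunes lattice states during the search):

```python
import torch

scores = torch.tensor([10.0, 9.5, 7.0, 3.2])  # hypothetical state scores
beam = 4.0
cutoff = scores.max() - beam          # 10.0 - 4.0 = 6.0
survivors = scores[scores >= cutoff]  # tensor([10.0, 9.5, 7.0])
```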
@@ -144,6 +204,7 @@ def decode_one_batch(
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
     batch: dict,
+    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -166,6 +227,9 @@ def decode_one_batch(
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or an HLG.
+        Used only when --decoding_method is fast_beam_search.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.
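For reference, the trivial graph mentioned here is constructed in main() further down; a minimal sketch, assuming a BPE vocabulary size (the 500 below is a hypothetical value):

```python
import k2
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vocab_size = 500  # hypothetical; main() below takes it from params.vocab_size
# An unconstrained decoding graph over the token alphabet; an HLG
# could be passed instead to constrain the search.
decoding_graph = k2.trivial_graph(vocab_size - 1, device=device)
```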
@@ -184,36 +248,62 @@ def decode_one_batch(
         x=feature, x_lens=feature_lens
     )
     hyps = []
-    batch_size = encoder_out.size(0)

-    for i in range(batch_size):
-        # fmt: off
-        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
-        # fmt: on
-        if params.decoding_method == "greedy_search":
-            hyp = greedy_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                max_sym_per_frame=params.max_sym_per_frame,
-            )
-        elif params.decoding_method == "beam_search":
-            hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        elif params.decoding_method == "modified_beam_search":
-            hyp = modified_beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
-        hyps.append(sp.decode(hyp).split())
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            elif params.decoding_method == "modified_beam_search":
+                hyp = modified_beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append(sp.decode(hyp).split())

     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
+    elif params.decoding_method == "fast_beam_search":
+        return {
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
+        }
     else:
-        return {f"beam_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}": hyps}


 def decode_dataset(
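One detail worth noting: --beam is declared with type=float, so a value passed on the command line arrives as a float and the result key renders with a decimal point. A quick check of the f-string above:

```python
beam = 4.0                       # "--beam 4" parsed by argparse's type=float
max_contexts, max_states = 4, 8  # the new defaults
key = (
    f"beam_{beam}_"
    f"max_contexts_{max_contexts}_"
    f"max_states_{max_states}"
)
print(key)  # beam_4.0_max_contexts_4_max_states_8
```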
@@ -221,6 +311,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
+    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -233,6 +324,9 @@ def decode_dataset(
       The neural model.
     sp:
       The BPE model.
+    decoding_graph:
+      The decoding graph. Can be either a `k2.trivial_graph` or an HLG.
+      Used only when --decoding_method is fast_beam_search.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
@@ -260,6 +354,7 @@ def decode_dataset(
             params=params,
             model=model,
             sp=sp,
+            decoding_graph=decoding_graph,
             batch=batch,
         )
@@ -340,12 +435,17 @@ def main():
     assert params.decoding_method in (
         "greedy_search",
         "beam_search",
+        "fast_beam_search",
         "modified_beam_search",
     )
     params.res_dir = params.exp_dir / params.decoding_method

     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    if "beam_search" in params.decoding_method:
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+    elif "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
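For example, with --epoch 28 --avg 15 (as in the usage string) and --beam 4 --max-contexts 4 --max-states 8, the file-name suffix assembled above becomes:

```python
epoch, avg = 28, 15
beam = 4.0  # argparse type=float turns "--beam 4" into 4.0
max_contexts, max_states = 4, 8
suffix = f"epoch-{epoch}-avg-{avg}"
suffix += f"-beam-{beam}"
suffix += f"-max-contexts-{max_contexts}"
suffix += f"-max-states-{max_states}"
print(suffix)  # epoch-28-avg-15-beam-4.0-max-contexts-4-max-states-8
```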
@@ -372,7 +472,12 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)

-    if params.avg == 1:
+    if params.avg_last_n > 0:
+        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
+    elif params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
         start = params.epoch - params.avg + 1
@@ -388,6 +493,11 @@ def main():
     model.eval()
     model.device = device

+    if params.decoding_method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -408,6 +518,7 @@ def main():
             params=params,
             model=model,
             sp=sp,
+            decoding_graph=decoding_graph,
         )

         save_results(

View File

@@ -64,6 +64,7 @@ class Decoder(nn.Module):
         assert context_size >= 1, context_size
         self.context_size = context_size
+        self.vocab_size = vocab_size
         if context_size > 1:
             self.conv = ScaledConv1d(
                 in_channels=embedding_dim,
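Recording vocab_size on the decoder presumably lets decoding code recover the vocabulary size from the model itself rather than from command-line parameters; a hypothetical sketch (the model.decoder attribute path is an assumption):

```python
import k2

def build_trivial_graph(model, device):
    # Hypothetical helper: the decoder now carries vocab_size, so the
    # trivial decoding graph can be sized from the model alone.
    return k2.trivial_graph(model.decoder.vocab_size - 1, device=device)
```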

View File

@@ -36,7 +36,7 @@ import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional, Tuple
+from typing import Any, Dict, Optional, Tuple

 import k2
 import sentencepiece as spm
@@ -48,6 +48,7 @@ from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
 from torch import Tensor
@@ -55,8 +56,9 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam

-from icefall.checkpoint import load_checkpoint
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall import diagnostics
@@ -112,6 +114,15 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
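The precedence between --start-batch and --start-epoch is implemented in load_checkpoint_if_available() below; condensed into a standalone sketch:

```python
from pathlib import Path
from typing import Optional

def resolve_resume_checkpoint(
    exp_dir: Path, start_batch: int, start_epoch: int
) -> Optional[Path]:
    # --start-batch wins; otherwise fall back to --start-epoch;
    # otherwise train from scratch.
    if start_batch > 0:
        return exp_dir / f"checkpoint-{start_batch}.pt"
    elif start_epoch > 0:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None
```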
@@ -192,6 +203,30 @@ def get_parser():
         help="Accumulate stats on activations, print them and exit.",
     )

+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=8000,
+        help="""Save checkpoint after processing this number of batches
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=20,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
     return parser
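With the defaults, a checkpoint is written every 8000 training batches and at most 20 of them are kept, so the retained checkpoint-xxx.pt files cover the most recent 160,000 batches:

```python
save_every_n = 8000  # --save-every-n default
keep_last_k = 20     # --keep-last-k default
# Checkpoints land at batch_idx_train = 8000, 16000, 24000, ...
# and remove_checkpoints() keeps only the newest keep_last_k of them.
print(save_every_n * keep_last_k)  # 160000 batches of history
```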
@@ -320,15 +355,16 @@ def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
-) -> None:
+) -> Optional[Dict[str, Any]]:
     """Load checkpoint from file.

-    If params.start_epoch is positive, it will load the checkpoint from
-    `params.start_epoch - 1`. Otherwise, this function does nothing.
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is positive, it will load the checkpoint from
+    `params.start_epoch - 1`.

-    Apart from loading state dict for `model`, `optimizer` and `scheduler`,
-    it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    Apart from loading state dict for `model` and `optimizer`, it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
     and `best_valid_loss` in `params`.

     Args:
@@ -338,20 +374,22 @@ def load_checkpoint_if_available(
         The training model.
       optimizer:
         The optimizer that we are using.
-      scheduler:
-        The learning rate scheduler we are using.
     Returns:
-      Return None.
+      Return a dict containing previously saved training info.
     """
-    if params.start_epoch <= 0:
-        return
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 0:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"

-    filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     saved_params = load_checkpoint(
         filename,
         model=model,
         optimizer=optimizer,
-        scheduler=scheduler,
     )

     keys = [
@@ -360,10 +398,13 @@ def load_checkpoint_if_available(
         "batch_idx_train",
         "best_train_loss",
         "best_valid_loss",
+        "cur_batch_idx",
     ]
     for k in keys:
         params[k] = saved_params[k]

+    params["start_epoch"] = saved_params["cur_epoch"]
+
     return saved_params
@@ -371,7 +412,7 @@ def save_checkpoint(
     params: AttributeDict,
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+    sampler: Optional[CutSampler] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -381,6 +422,10 @@ def save_checkpoint(
         It is returned by :func:`get_params`.
       model:
         The training model.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+        The sampler for the training dataset.
     """
     if rank != 0:
         return
@@ -390,7 +435,7 @@ def save_checkpoint(
         model=model,
         params=params,
         optimizer=optimizer,
-        scheduler=scheduler,
+        sampler=sampler,
         rank=rank,
     )
@@ -509,6 +554,7 @@ def train_one_epoch(
     valid_dl: torch.utils.data.DataLoader,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
+    rank: int = 0,
 ) -> None:
     """Train the model for one epoch.
@@ -531,12 +577,21 @@ def train_one_epoch(
       Writer to write log messages to tensorboard.
     world_size:
       Number of nodes in DDP training. If it is 1, DDP is disabled.
+    rank:
+      The rank of the node in DDP training. If no DDP is used, it should
+      be set to 0.
     """
     model.train()

     tot_loss = MetricsTracker()

+    cur_batch_idx = params.get("cur_batch_idx", 0)
+
     for batch_idx, batch in enumerate(train_dl):
+        if batch_idx < cur_batch_idx:
+            continue
+        cur_batch_idx = batch_idx
+
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
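This is the consumer side of the cur_batch_idx value saved by the periodic checkpoint further down: on resume, batches the interrupted epoch already processed are skipped so the run continues mid-epoch. Schematically (params and train_dl as in the surrounding code):

```python
# Fast-forward to where the interrupted epoch left off.
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl):
    if batch_idx < cur_batch_idx:
        continue  # already consumed before the checkpoint was written
    cur_batch_idx = batch_idx
    # ... normal training step ...
```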
@@ -560,6 +615,27 @@ def train_one_epoch(
         if params.print_diagnostics and batch_idx == 5:
             return

+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            params.cur_batch_idx = batch_idx
+            save_checkpoint_with_global_batch_idx(
+                out_dir=params.exp_dir,
+                global_batch_idx=params.batch_idx_train,
+                model=model,
+                params=params,
+                optimizer=optimizer,
+                sampler=train_dl.sampler,
+                rank=rank,
+            )
+            del params.cur_batch_idx
+            remove_checkpoints(
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
+            )
+
         if batch_idx % params.log_interval == 0:
             logging.info(
                 f"Epoch {params.cur_epoch}, "
@@ -688,7 +764,14 @@ def run(rank, world_size, args):

     train_cuts = train_cuts.filter(remove_short_and_long_utt)

-    train_dl = librispeech.train_dataloaders(train_cuts)
+    if checkpoints and "sampler" in checkpoints:
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )

     valid_cuts = librispeech.dev_clean_cuts()
     valid_cuts += librispeech.dev_other_cuts()
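Restoring the sampler state is what makes the cur_batch_idx fast-forward in train_one_epoch() land on the right batches. A hedged sketch of the round trip, assuming lhotse's CutSampler state_dict()/load_state_dict() support and that train_dataloaders() forwards the state to its sampler:

```python
# Save side: save_checkpoint_with_global_batch_idx(..., sampler=train_dl.sampler)
# serializes the sampler's position, e.g.:
sampler_state = train_dl.sampler.state_dict()

# Restore side (as added above): hand the saved state back when the
# dataloader is rebuilt, so iteration resumes where it stopped.
train_dl = librispeech.train_dataloaders(
    train_cuts, sampler_state_dict=sampler_state
)
```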
@@ -728,6 +811,7 @@ def run(rank, world_size, args):
             valid_dl=valid_dl,
             tb_writer=tb_writer,
             world_size=world_size,
+            rank=rank,
         )

         if params.print_diagnostics:
@@ -738,6 +822,7 @@ def run(rank, world_size, args):
             params=params,
             model=model,
             optimizer=optimizer,
+            sampler=train_dl.sampler,
             rank=rank,
         )

View File

@@ -135,8 +135,13 @@ def get_diagnostics_for_dim(
             return ""
         count = sum(counts)
         stats = stats / count
-        stats, _ = torch.symeig(stats)
-        stats = stats.abs().sqrt()
+        try:
+            eigs, _ = torch.symeig(stats)
+            stats = eigs.abs().sqrt()
+        except:
+            print("Error getting eigenvalues, trying another method")
+            eigs, _ = torch.eig(stats)
+            stats = eigs.abs().sqrt()
         # sqrt so it reflects data magnitude, like stddev- not variance
     elif sizes_same:
         stats = torch.stack(stats).sum(dim=0)
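A note on the fallback: torch.symeig assumes a symmetric matrix and can fail to converge; it is also deprecated in newer PyTorch. On recent versions the same computation could be written, as a hedged sketch:

```python
import torch

def eig_magnitudes(stats: torch.Tensor) -> torch.Tensor:
    # eigvalsh is the modern replacement for symeig (symmetric input);
    # eigvals handles the general case if eigvalsh fails.
    try:
        eigs = torch.linalg.eigvalsh(stats)
    except RuntimeError:
        eigs = torch.linalg.eigvals(stats)  # may be complex
    return eigs.abs().sqrt()
```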