Re-organize the BBPE recipe for AIShell
Commit 674390e63e (parent ad94191055)
@@ -4,7 +4,7 @@
 #### Zipformer (Byte-level BPE)

-[./zipformer_bbpe](./zipformer_bbpe/)
+[./zipformer](./zipformer/)

 It's reworked Zipformer with Pruned RNNT loss, trained with Byte-level BPE, `vocab_size` set to 500.

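A note on the "Byte-level BPE" description above: the recipe splits Chinese text per CJK character, rewrites the UTF-8 bytes as printable symbols, and only then applies the 500-piece BPE model. Below is a minimal sketch of that pipeline, assuming icefall is installed and `data/lang_bbpe_500/bbpe.model` has been trained; the helper names come from the imports shown later in this diff, while the exact call order is an assumption.

```python
# Illustrative only; not part of this commit.
import sentencepiece as spm

from icefall import byte_encode                  # same helper train_bbpe.py imports
from icefall.utils import tokenize_by_CJK_char   # same helper train_bbpe.py imports

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bbpe_500/bbpe.model")         # vocab_size = 500

text = "甚至出现交易几乎停滞的情况"
# Split CJK characters apart, then map the raw UTF-8 bytes to printable
# symbols so the BPE model can operate on them.
byte_level_text = byte_encode(tokenize_by_CJK_char(text))
token_ids = sp.encode(byte_level_text, out_type=int)
print(len(token_ids), token_ids[:10])
```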
@@ -21,14 +21,14 @@ It's reworked Zipformer with Pruned RNNT loss, trained with Byte-level BPE, `vocab_size` set to 500.
 export CUDA_VISIBLE_DEVICES="0,1"

-./zipformer_bbpe/train.py \
+./zipformer/train_bbpe.py \
   --world-size 2 \
   --num-epochs 40 \
   --start-epoch 1 \
   --use-fp16 1 \
   --context-size 2 \
   --enable-musan 0 \
-  --exp-dir zipformer/exp \
+  --exp-dir zipformer/exp_bbpe \
   --max-duration 1000 \
   --enable-musan 0 \
   --base-lr 0.045 \
@@ -40,11 +40,11 @@ export CUDA_VISIBLE_DEVICES="0,1"
 Command for decoding is:
 ```bash
 for m in greedy_search modified_beam_search fast_beam_search ; do
-  ./zipformer/decode.py \
+  ./zipformer/decode_bbpe.py \
     --epoch 40 \
     --avg 10 \
     --exp-dir ./zipformer_bbpe/exp \
-    --lang-dir data/lang_bbpe_500 \
+    --bpe-model data/lang_bbpe_500/bbpe.model \
     --context-size 2 \
     --decoding-method $m
 done
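On the `--lang-dir` → `--bpe-model` change above: decoding now only needs the trained byte-level BPE model file rather than a full lang directory. A minimal sketch of what `decode_bbpe.py` presumably does with that flag; the `<blk>` piece name is an assumed icefall convention, not something shown in this diff.

```python
# Illustrative only; not part of this commit.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bbpe_500/bbpe.model")   # value of --bpe-model

blank_id = sp.piece_to_id("<blk>")         # assumed blank-symbol convention
vocab_size = sp.get_piece_size()           # 500 for this recipe
print(blank_id, vocab_size)
```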
@@ -93,7 +93,6 @@ import torch.nn as nn
 from asr_datamodule import AishellAsrDataModule
 from beam_search import (
     beam_search,
-    fast_beam_search_nbest,
     fast_beam_search_nbest_oracle,
     fast_beam_search_one_best,
     greedy_search,
@@ -53,40 +53,40 @@ import copy
 import logging
 import warnings
 from pathlib import Path
-from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union

 import k2
-import optim
 import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import AishellAsrDataModule
-from decoder import Decoder
-from joiner import Joiner
 from lhotse.cut import Cut
-from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
-from model import AsrModel
 from optim import Eden, ScaledAdam
-from scaling import ScheduledFloat
-from subsampling import Conv2dSubsampling
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
-from zipformer import Zipformer2
+from train import (
+    LRSchedulerType,
+    add_model_arguments,
+    get_adjusted_batch_count,
+    get_model,
+    get_params,
+    load_checkpoint_if_available,
+    save_checkpoint,
+    set_batch_count,
+)

 from icefall import byte_encode, diagnostics
-from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
     save_checkpoint_with_global_batch_idx,
     update_averaged_model,
 )
 from icefall.dist import cleanup_dist, setup_dist
-from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
@@ -97,148 +97,6 @@ from icefall.utils import (
     tokenize_by_CJK_char,
 )

-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
-
-
-def get_adjusted_batch_count(params: AttributeDict) -> float:
-    # returns the number of batches we would have used so far if we had used the reference
-    # duration. This is for purposes of set_batch_count().
-    return (
-        params.batch_idx_train
-        * (params.max_duration * params.world_size)
-        / params.ref_duration
-    )
-
-
-def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
-    if isinstance(model, DDP):
-        # get underlying nn.Module
-        model = model.module
-    for name, module in model.named_modules():
-        if hasattr(module, "batch_count"):
-            module.batch_count = batch_count
-        if hasattr(module, "name"):
-            module.name = name
-
-
-def add_model_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument(
-        "--num-encoder-layers",
-        type=str,
-        default="2,2,3,4,3,2",
-        help="Number of zipformer encoder layers per stack, comma separated.",
-    )
-
-    parser.add_argument(
-        "--downsampling-factor",
-        type=str,
-        default="1,2,4,8,4,2",
-        help="Downsampling factor for each stack of encoder layers.",
-    )
-
-    parser.add_argument(
-        "--feedforward-dim",
-        type=str,
-        default="512,768,1024,1536,1024,768",
-        help="""Feedforward dimension of the zipformer encoder layers, per stack, comma separated.""",
-    )
-
-    parser.add_argument(
-        "--num-heads",
-        type=str,
-        default="4,4,4,8,4,4",
-        help="""Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.""",
-    )
-
-    parser.add_argument(
-        "--encoder-dim",
-        type=str,
-        default="192,256,384,512,384,256",
-        help="""Embedding dimension in encoder stacks: a single int or comma-separated list.""",
-    )
-
-    parser.add_argument(
-        "--query-head-dim",
-        type=str,
-        default="32",
-        help="""Query/key dimension per head in encoder stacks: a single int or comma-separated list.""",
-    )
-
-    parser.add_argument(
-        "--value-head-dim",
-        type=str,
-        default="12",
-        help="""Value dimension per head in encoder stacks: a single int or comma-separated list.""",
-    )
-
-    parser.add_argument(
-        "--pos-head-dim",
-        type=str,
-        default="4",
-        help="""Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.""",
-    )
-
-    parser.add_argument(
-        "--pos-dim",
-        type=int,
-        default="48",
-        help="Positional-encoding embedding dimension",
-    )
-
-    parser.add_argument(
-        "--encoder-unmasked-dim",
-        type=str,
-        default="192,192,256,256,256,192",
-        help="""Unmasked dimensions in the encoders, relates to augmentation during training. A single int or comma-separated list. Must be <= each corresponding encoder_dim.""",
-    )
-
-    parser.add_argument(
-        "--cnn-module-kernel",
-        type=str,
-        default="31,31,15,15,15,31",
-        help="""Sizes of convolutional kernels in convolution modules in each encoder stack: a single int or comma-separated list.""",
-    )
-
-    parser.add_argument(
-        "--decoder-dim",
-        type=int,
-        default=512,
-        help="Embedding dimension in the decoder model.",
-    )
-
-    parser.add_argument(
-        "--joiner-dim",
-        type=int,
-        default=512,
-        help="""Dimension used in the joiner model.
-        Outputs from the encoder and decoder model are projected
-        to this dimension before adding.
-        """,
-    )
-
-    parser.add_argument(
-        "--causal",
-        type=str2bool,
-        default=False,
-        help="If True, use causal version of model.",
-    )
-
-    parser.add_argument(
-        "--chunk-size",
-        type=str,
-        default="16,32,64,-1",
-        help="""Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. Must be just -1 if --causal=False""",
-    )
-
-    parser.add_argument(
-        "--left-context-frames",
-        type=str,
-        default="64,128,256,-1",
-        help="""Maximum left-contexts for causal training, measured in frames which will
-        be converted to a number of chunks. If splitting into chunks,
-        chunk left-context frames will be chosen randomly from this list; else not relevant.""",
-    )
-
-
 def get_parser():
     parser = argparse.ArgumentParser(
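The `get_adjusted_batch_count` helper removed above (train_bbpe.py now imports it from `./zipformer/train.py`) rescales the raw batch index by how much audio each step actually covers, so warmup and dropout schedules behave consistently across different `--max-duration` and `--world-size` settings. A small worked example with illustrative numbers:

```python
# Worked example for get_adjusted_batch_count(); values are illustrative,
# not the recipe defaults.
batch_idx_train = 3000   # optimizer steps taken so far
max_duration = 1000      # --max-duration: seconds of audio per batch per GPU
world_size = 2           # --world-size: number of GPUs
ref_duration = 600       # reference duration assumed by the schedule

adjusted = batch_idx_train * (max_duration * world_size) / ref_duration
print(adjusted)  # 10000.0, the value passed to set_batch_count()
```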
@@ -446,273 +304,6 @@ def get_parser():
     return parser


-def get_params() -> AttributeDict:
-    """Return a dict containing training parameters.
-
-    All training related parameters that are not passed from the commandline
-    are saved in the variable `params`.
-
-    Commandline options are merged into `params` after they are parsed, so
-    you can also access them via `params`.
-
-    Explanation of options saved in `params`:
-
-        - best_train_loss: Best training loss so far. It is used to select
-                           the model that has the lowest training loss. It is
-                           updated during the training.
-
-        - best_valid_loss: Best validation loss so far. It is used to select
-                           the model that has the lowest validation loss. It is
-                           updated during the training.
-
-        - best_train_epoch: It is the epoch that has the best training loss.
-
-        - best_valid_epoch: It is the epoch that has the best validation loss.
-
-        - batch_idx_train: Used to writing statistics to tensorboard. It
-                           contains number of batches trained so far across
-                           epochs.
-
-        - log_interval: Print training loss if batch_idx % log_interval` is 0
-
-        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
-
-        - valid_interval: Run validation if batch_idx % valid_interval is 0
-
-        - feature_dim: The model input dim. It has to match the one used
-                       in computing features.
-
-        - subsampling_factor: The subsampling factor for the model.
-
-        - encoder_dim: Hidden dim for multi-head attention model.
-
-        - num_decoder_layers: Number of decoder layer of transformer decoder.
-
-        - warm_step: The warmup period that dictates the decay of the
-              scale on "simple" (un-pruned) loss.
-    """
-    params = AttributeDict(
-        {
-            "best_train_loss": float("inf"),
-            "best_valid_loss": float("inf"),
-            "best_train_epoch": -1,
-            "best_valid_epoch": -1,
-            "batch_idx_train": 0,
-            "log_interval": 50,
-            "reset_interval": 200,
-            "valid_interval": 3000,
-            # parameters for zipformer
-            "feature_dim": 80,
-            "subsampling_factor": 4,  # not passed in, this is fixed.
-            "warm_step": 2000,
-            "env_info": get_env_info(),
-        }
-    )
-
-    return params
-
-
-def _to_int_tuple(s: str):
-    return tuple(map(int, s.split(",")))
-
-
-def get_encoder_embed(params: AttributeDict) -> nn.Module:
-    # encoder_embed converts the input of shape (N, T, num_features)
-    # to the shape (N, (T - 7) // 2, encoder_dims).
-    # That is, it does two things simultaneously:
-    #   (1) subsampling: T -> (T - 7) // 2
-    #   (2) embedding: num_features -> encoder_dims
-    # In the normal configuration, we will downsample once more at the end
-    # by a factor of 2, and most of the encoder stacks will run at a lower
-    # sampling rate.
-    encoder_embed = Conv2dSubsampling(
-        in_channels=params.feature_dim,
-        out_channels=_to_int_tuple(params.encoder_dim)[0],
-        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
-    )
-    return encoder_embed
-
-
-def get_encoder_model(params: AttributeDict) -> nn.Module:
-    encoder = Zipformer2(
-        output_downsampling_factor=2,
-        downsampling_factor=_to_int_tuple(params.downsampling_factor),
-        num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
-        encoder_dim=_to_int_tuple(params.encoder_dim),
-        encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
-        query_head_dim=_to_int_tuple(params.query_head_dim),
-        pos_head_dim=_to_int_tuple(params.pos_head_dim),
-        value_head_dim=_to_int_tuple(params.value_head_dim),
-        pos_dim=params.pos_dim,
-        num_heads=_to_int_tuple(params.num_heads),
-        feedforward_dim=_to_int_tuple(params.feedforward_dim),
-        cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
-        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
-        warmup_batches=4000.0,
-        causal=params.causal,
-        chunk_size=_to_int_tuple(params.chunk_size),
-        left_context_frames=_to_int_tuple(params.left_context_frames),
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict) -> nn.Module:
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        decoder_dim=params.decoder_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict) -> nn.Module:
-    joiner = Joiner(
-        encoder_dim=max(_to_int_tuple(params.encoder_dim)),
-        decoder_dim=params.decoder_dim,
-        joiner_dim=params.joiner_dim,
-        vocab_size=params.vocab_size,
-    )
-    return joiner
-
-
-def get_model(params: AttributeDict) -> nn.Module:
-    encoder_embed = get_encoder_embed(params)
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-
-    model = AsrModel(
-        encoder_embed=encoder_embed,
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-        encoder_dim=int(max(params.encoder_dim.split(","))),
-        decoder_dim=params.decoder_dim,
-        vocab_size=params.vocab_size,
-    )
-    return model
-
-
-def load_checkpoint_if_available(
-    params: AttributeDict,
-    model: nn.Module,
-    model_avg: nn.Module = None,
-    optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[LRSchedulerType] = None,
-) -> Optional[Dict[str, Any]]:
-    """Load checkpoint from file.
-
-    If params.start_batch is positive, it will load the checkpoint from
-    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
-    params.start_epoch is larger than 1, it will load the checkpoint from
-    `params.start_epoch - 1`.
-
-    Apart from loading state dict for `model` and `optimizer` it also updates
-    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
-    and `best_valid_loss` in `params`.
-
-    Args:
-      params:
-        The return value of :func:`get_params`.
-      model:
-        The training model.
-      model_avg:
-        The stored model averaged from the start of training.
-      optimizer:
-        The optimizer that we are using.
-      scheduler:
-        The scheduler that we are using.
-    Returns:
-      Return a dict containing previously saved training info.
-    """
-    if params.start_batch > 0:
-        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
-    elif params.start_epoch > 1:
-        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
-    else:
-        return None
-
-    assert filename.is_file(), f"{filename} does not exist!"
-
-    saved_params = load_checkpoint(
-        filename,
-        model=model,
-        model_avg=model_avg,
-        optimizer=optimizer,
-        scheduler=scheduler,
-    )
-
-    keys = [
-        "best_train_epoch",
-        "best_valid_epoch",
-        "batch_idx_train",
-        "best_train_loss",
-        "best_valid_loss",
-    ]
-    for k in keys:
-        params[k] = saved_params[k]
-
-    if params.start_batch > 0:
-        if "cur_epoch" in saved_params:
-            params["start_epoch"] = saved_params["cur_epoch"]
-
-        if "cur_batch_idx" in saved_params:
-            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
-
-    return saved_params
-
-
-def save_checkpoint(
-    params: AttributeDict,
-    model: Union[nn.Module, DDP],
-    model_avg: Optional[nn.Module] = None,
-    optimizer: Optional[torch.optim.Optimizer] = None,
-    scheduler: Optional[LRSchedulerType] = None,
-    sampler: Optional[CutSampler] = None,
-    scaler: Optional[GradScaler] = None,
-    rank: int = 0,
-) -> None:
-    """Save model, optimizer, scheduler and training stats to file.
-
-    Args:
-      params:
-        It is returned by :func:`get_params`.
-      model:
-        The training model.
-      model_avg:
-        The stored model averaged from the start of training.
-      optimizer:
-        The optimizer used in the training.
-      sampler:
-        The sampler for the training dataset.
-      scaler:
-        The scaler used for mix precision training.
-    """
-    if rank != 0:
-        return
-    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
-    save_checkpoint_impl(
-        filename=filename,
-        model=model,
-        model_avg=model_avg,
-        params=params,
-        optimizer=optimizer,
-        scheduler=scheduler,
-        sampler=sampler,
-        scaler=scaler,
-        rank=rank,
-    )
-
-    if params.best_train_epoch == params.cur_epoch:
-        best_train_filename = params.exp_dir / "best-train-loss.pt"
-        copyfile(src=filename, dst=best_train_filename)
-
-    if params.best_valid_epoch == params.cur_epoch:
-        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
-        copyfile(src=filename, dst=best_valid_filename)
-
-
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
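The `get_encoder_embed` comment removed above documents the front-end shape change, (N, T, num_features) → (N, (T - 7) // 2, encoder_dim[0]), followed by one more 2x downsampling inside Zipformer2 (`output_downsampling_factor=2`). A quick shape check with illustrative numbers:

```python
# Shape arithmetic for the Conv2dSubsampling front-end described above.
# Numbers are illustrative (10 s of 80-dim fbank at 100 frames/second).
N, T, num_features = 8, 1000, 80

T_sub = (T - 7) // 2   # Conv2dSubsampling: T -> (T - 7) // 2
T_out = T_sub // 2     # output_downsampling_factor=2 in Zipformer2 (approximate)

print(T_sub, T_out)    # 496 frames (~50 Hz), 248 frames (~25 Hz)
```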
@@ -1287,7 +878,6 @@ def display_and_save_batch(
     logging.info(f"features shape: {features.shape}")

-    texts = supervisions["text"]
     y = sp.encode(supervisions["text"], out_type=int)
     num_tokens = sum(len(i) for i in y)
     logging.info(f"num tokens: {num_tokens}")
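One practical consequence of the checkpoint helpers removed earlier (now imported from `./zipformer/train.py`): resuming picks the checkpoint file from `--start-batch` or `--start-epoch`. A tiny sketch of that selection logic, mirroring the removed `load_checkpoint_if_available` docstring; the paths are illustrative.

```python
# Checkpoint-file selection as described by the removed load_checkpoint_if_available().
from pathlib import Path

exp_dir = Path("zipformer/exp_bbpe")
start_batch, start_epoch = 0, 21

if start_batch > 0:
    filename = exp_dir / f"checkpoint-{start_batch}.pt"
elif start_epoch > 1:
    filename = exp_dir / f"epoch-{start_epoch - 1}.pt"   # e.g. epoch-20.pt
else:
    filename = None   # start training from scratch

print(filename)
```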
@@ -1 +0,0 @@
-../pruned_transducer_stateless2/asr_datamodule.py
@@ -1 +0,0 @@
-../pruned_transducer_stateless2/beam_search.py
@@ -1 +0,0 @@
-../../../librispeech/ASR/zipformer/decoder.py
@@ -1 +0,0 @@
-../pruned_transducer_stateless2/encoder_interface.py
@@ -1 +0,0 @@
-../../../librispeech/ASR/zipformer/joiner.py
@@ -1 +0,0 @@
-../../../librispeech/ASR/zipformer/model.py
@@ -1 +0,0 @@
-../../../librispeech/ASR/zipformer/optim.py
@@ -1 +0,0 @@
-../../../librispeech/ASR/zipformer/scaling.py
@@ -1 +0,0 @@
-../../../librispeech/ASR/zipformer/scaling_converter.py
@@ -1 +0,0 @@
-../../../librispeech/ASR/zipformer/subsampling.py
@@ -1 +0,0 @@
-../../../librispeech/ASR/zipformer/zipformer.py