mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Shallow fusion for Aishell (#954)
* add shallow fusion and LODR for aishell * update RESULTS * add save by iterations
This commit is contained in:
parent
46bf6df62f
commit
d337398d29
@ -15,6 +15,8 @@ It uses pruned RNN-T.
|
||||
|------------------------|------|------|---------------------------------------|
|
||||
| greedy search | 5.39 | 5.09 | --epoch 29 --avg 5 --max-duration 600 |
|
||||
| modified beam search | 5.05 | 4.79 | --epoch 29 --avg 5 --max-duration 600 |
|
||||
| modified beam search + RNNLM shallow fusion | 4.73 | 4.53 | --epoch 29 --avg 5 --max-duration 600 |
|
||||
| modified beam search + LODR | 4.57 | 4.37 | --epoch 29 --avg 5 --max-duration 600 |
|
||||
| fast beam search | 5.13 | 4.91 | --epoch 29 --avg 5 --max-duration 600 |
|
||||
|
||||
Training command is:
|
||||
@ -73,6 +75,78 @@ for epoch in 29; do
|
||||
done
|
||||
```
|
||||
|
||||
We provide the option of shallow fusion with a RNN language model. The pre-trained language model is
|
||||
available at <https://huggingface.co/marcoyang/icefall-aishell-rnn-lm>. To decode with the language model,
|
||||
please use the following command:
|
||||
|
||||
```bash
|
||||
# download pre-trained model
|
||||
git lfs install
|
||||
git clone https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20
|
||||
|
||||
aishell_exp=icefall-aishell-pruned-transducer-stateless3-2022-06-20/
|
||||
|
||||
pushd ${aishell_exp}/exp
|
||||
ln -s pretrained-epoch-29-avg-5-torch-1.10.0.pt epoch-99.pt
|
||||
popd
|
||||
|
||||
# download RNN LM
|
||||
git lfs install
|
||||
git clone https://huggingface.co/marcoyang/icefall-aishell-rnn-lm
|
||||
rnnlm_dir=icefall-aishell-rnn-lm
|
||||
|
||||
# RNNLM shallow fusion
|
||||
for lm_scale in $(seq 0.26 0.02 0.34); do
|
||||
python ./pruned_transducer_stateless3/decode.py \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--lang-dir ${aishell_exp}/data/lang_char \
|
||||
--exp-dir ${aishell_exp}/exp \
|
||||
--use-averaged-model False \
|
||||
--decoding-method modified_beam_search_lm_shallow_fusion \
|
||||
--use-shallow-fusion 1 \
|
||||
--lm-type rnn \
|
||||
--lm-exp-dir ${rnnlm_dir}/exp \
|
||||
--lm-epoch 99 \
|
||||
--lm-scale $lm_scale \
|
||||
--lm-avg 1 \
|
||||
--rnn-lm-embedding-dim 2048 \
|
||||
--rnn-lm-hidden-dim 2048 \
|
||||
--rnn-lm-num-layers 2 \
|
||||
--lm-vocab-size 4336
|
||||
done
|
||||
|
||||
# RNNLM Low-order density ratio (LODR) with a 2-gram
|
||||
|
||||
cp ${rnnlm_dir}/2gram.fst.txt ${aishell_exp}/data/lang_char/2gram.fst.txt
|
||||
|
||||
for lm_scale in 0.48; do
|
||||
for LODR_scale in -0.28; do
|
||||
python ./pruned_transducer_stateless3/decode.py \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--lang-dir ${aishell_exp}/data/lang_char \
|
||||
--exp-dir ${aishell_exp}/exp \
|
||||
--use-averaged-model False \
|
||||
--decoding-method modified_beam_search_LODR \
|
||||
--use-shallow-fusion 1 \
|
||||
--lm-type rnn \
|
||||
--lm-exp-dir ${rnnlm_dir}/exp \
|
||||
--lm-epoch 99 \
|
||||
--lm-scale $lm_scale \
|
||||
--lm-avg 1 \
|
||||
--rnn-lm-embedding-dim 2048 \
|
||||
--rnn-lm-hidden-dim 2048 \
|
||||
--rnn-lm-num-layers 2 \
|
||||
--lm-vocab-size 4336 \
|
||||
--tokens-ngram 2 \
|
||||
--backoff-id 4336 \
|
||||
--ngram-lm-scale $LODR_scale
|
||||
done
|
||||
done
|
||||
|
||||
```
|
||||
|
||||
Pretrained models, training logs, decoding logs, and decoding results
|
||||
are available at
|
||||
<https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20>
|
||||
|
0
egs/aishell/ASR/local/prepare_char_lm_training_data.py
Normal file → Executable file
0
egs/aishell/ASR/local/prepare_char_lm_training_data.py
Normal file → Executable file
1
egs/aishell/ASR/local/sort_lm_training_data.py
Symbolic link
1
egs/aishell/ASR/local/sort_lm_training_data.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/sort_lm_training_data.py
|
@ -230,12 +230,14 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then
|
||||
cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt
|
||||
fi
|
||||
|
||||
|
||||
# training words
|
||||
./local/prepare_char_lm_training_data.py \
|
||||
--lang-char data/lang_char \
|
||||
--lm-data $dl_dir/lm/aishell-train-word.txt \
|
||||
--lm-archive $out_dir/lm_data.pt
|
||||
|
||||
# valid words
|
||||
if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then
|
||||
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
|
||||
aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid
|
||||
@ -249,6 +251,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
--lm-data $dl_dir/lm/aishell-valid-word.txt \
|
||||
--lm-archive $out_dir/lm_data_valid.pt
|
||||
|
||||
# test words
|
||||
if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then
|
||||
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
|
||||
aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid
|
||||
@ -303,9 +306,9 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
||||
--hidden-dim 512 \
|
||||
--num-layers 2 \
|
||||
--batch-size 400 \
|
||||
--exp-dir rnnlm_char/exp \
|
||||
--lm-data data/lm_training_char/sorted_lm_data.pt \
|
||||
--lm-data-valid data/lm_training_char/sorted_lm_data-valid.pt \
|
||||
--exp-dir rnnlm_char/exp_aishell1_small \
|
||||
--lm-data data/lm_char/sorted_lm_data_aishell1.pt \
|
||||
--lm-data-valid data/lm_char/sorted_lm_data_valid.pt \
|
||||
--vocab-size 4336 \
|
||||
--master-port 12345
|
||||
fi
|
||||
|
@ -54,6 +54,40 @@ Usage:
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
|
||||
(5) modified beam search (with LM shallow fusion)
|
||||
./pruned_transducer_stateless3/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search_lm_shallow_fusion \
|
||||
--beam-size 4 \
|
||||
--lm-type rnn \
|
||||
--lm-scale 0.3 \
|
||||
--lm-exp-dir /path/to/LM \
|
||||
--rnn-lm-epoch 99 \
|
||||
--rnn-lm-avg 1 \
|
||||
--rnn-lm-num-layers 3 \
|
||||
--rnn-lm-tie-weights 1
|
||||
|
||||
(6) modified beam search with LM shallow fusion + LODR
|
||||
./pruned_transducer_stateless3/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--max-duration 600 \
|
||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||
--decoding-method modified_beam_search_LODR \
|
||||
--beam-size 4 \
|
||||
--lm-type rnn \
|
||||
--lm-scale 0.48 \
|
||||
--lm-exp-dir /path/to/LM \
|
||||
--rnn-lm-epoch 99 \
|
||||
--rnn-lm-avg 1 \
|
||||
--rnn-lm-num-layers 3 \
|
||||
--rnn-lm-tie-weights 1
|
||||
--tokens-ngram 2 \
|
||||
--ngram-lm-scale -0.28 \
|
||||
"""
|
||||
|
||||
|
||||
@ -74,9 +108,12 @@ from beam_search import (
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
modified_beam_search_lm_shallow_fusion,
|
||||
modified_beam_search_LODR,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall import LmScorer, NgramLm
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
@ -212,6 +249,60 @@ def get_parser():
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-shallow-fusion",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Use neural network LM for shallow fusion.
|
||||
If you want to use LODR, you will also need to set this to true
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lm-type",
|
||||
type=str,
|
||||
default="rnn",
|
||||
help="Type of NN lm",
|
||||
choices=["rnn", "transformer"],
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lm-scale",
|
||||
type=float,
|
||||
default=0.3,
|
||||
help="""The scale of the neural network LM
|
||||
Used only when `--use-shallow-fusion` is set to True.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens-ngram",
|
||||
type=int,
|
||||
default=2,
|
||||
help="""Token Ngram used for rescoring.
|
||||
Used only when the decoding method is
|
||||
modified_beam_search_ngram_rescoring""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--backoff-id",
|
||||
type=int,
|
||||
default=500,
|
||||
help="""ID of the backoff symbol.
|
||||
Used only when the decoding method is
|
||||
modified_beam_search_ngram_rescoring""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -223,6 +314,9 @@ def decode_one_batch(
|
||||
token_table: k2.SymbolTable,
|
||||
batch: dict,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
ngram_lm: Optional[NgramLm] = None,
|
||||
ngram_lm_scale: float = 1.0,
|
||||
LM: Optional[LmScorer] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
@ -287,6 +381,24 @@ def decode_one_batch(
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
|
||||
hyp_tokens = modified_beam_search_lm_shallow_fusion(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
LM=LM,
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search_LODR":
|
||||
hyp_tokens = modified_beam_search_LODR(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
LODR_lm=ngram_lm,
|
||||
LODR_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
)
|
||||
else:
|
||||
hyp_tokens = []
|
||||
batch_size = encoder_out.size(0)
|
||||
@ -334,6 +446,9 @@ def decode_dataset(
|
||||
model: nn.Module,
|
||||
token_table: k2.SymbolTable,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
ngram_lm: Optional[NgramLm] = None,
|
||||
ngram_lm_scale: float = 1.0,
|
||||
LM: Optional[LmScorer] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
@ -379,6 +494,9 @@ def decode_dataset(
|
||||
token_table=token_table,
|
||||
decoding_graph=decoding_graph,
|
||||
batch=batch,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
@ -445,6 +563,7 @@ def save_results(
|
||||
def main():
|
||||
parser = get_parser()
|
||||
AsrDataModule.add_arguments(parser)
|
||||
LmScorer.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
@ -458,6 +577,8 @@ def main():
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"modified_beam_search",
|
||||
"modified_beam_search_LODR",
|
||||
"modified_beam_search_lm_shallow_fusion",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
@ -479,6 +600,19 @@ def main():
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
if "ngram" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
if params.use_shallow_fusion:
|
||||
if params.lm_type == "rnn":
|
||||
params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}"
|
||||
elif params.lm_type == "transformer":
|
||||
params.suffix += f"-transformer-lm-scale-{params.lm_scale}"
|
||||
|
||||
if "LODR" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
|
||||
)
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
@ -588,6 +722,35 @@ def main():
|
||||
else:
|
||||
decoding_graph = None
|
||||
|
||||
# only load N-gram LM when needed
|
||||
if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
|
||||
lm_filename = params.lang_dir / f"{params.tokens_ngram}gram.fst.txt"
|
||||
logging.info(f"lm filename: {lm_filename}")
|
||||
ngram_lm = NgramLm(
|
||||
lm_filename,
|
||||
backoff_id=params.backoff_id,
|
||||
is_binary=False,
|
||||
)
|
||||
logging.info(f"num states: {ngram_lm.lm.num_states}")
|
||||
ngram_lm_scale = params.ngram_lm_scale
|
||||
else:
|
||||
ngram_lm = None
|
||||
ngram_lm_scale = None
|
||||
|
||||
# only load the neural network LM if doing shallow fusion
|
||||
if params.use_shallow_fusion:
|
||||
LM = LmScorer(
|
||||
lm_type=params.lm_type,
|
||||
params=params,
|
||||
device=device,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
LM.to(device)
|
||||
LM.eval()
|
||||
|
||||
else:
|
||||
LM = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
@ -610,6 +773,9 @@ def main():
|
||||
model=model,
|
||||
token_table=lexicon.token_table,
|
||||
decoding_graph=decoding_graph,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
)
|
||||
|
||||
save_results(
|
||||
|
@ -1 +0,0 @@
|
||||
/ceph-fj/fangjun/open-source/icefall-aishell/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1
|
@ -550,7 +550,6 @@ def decode_one_batch(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
sp=sp,
|
||||
LM=LM,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
@ -561,7 +560,6 @@ def decode_one_batch(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
sp=sp,
|
||||
LODR_lm=ngram_lm,
|
||||
LODR_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
|
@ -1863,7 +1863,6 @@ def modified_beam_search_LODR(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
LODR_lm: NgramLm,
|
||||
LODR_lm_scale: float,
|
||||
LM: LmScorer,
|
||||
@ -1883,8 +1882,6 @@ def modified_beam_search_LODR(
|
||||
encoder_out_lens (torch.Tensor):
|
||||
A 1-D tensor of shape (N,), containing the number of
|
||||
valid frames in encoder_out before padding.
|
||||
sp:
|
||||
Sentence piece generator.
|
||||
LODR_lm:
|
||||
A low order n-gram LM, whose score will be subtracted during shallow fusion
|
||||
LODR_lm_scale:
|
||||
@ -1912,7 +1909,7 @@ def modified_beam_search_LODR(
|
||||
)
|
||||
|
||||
blank_id = model.decoder.blank_id
|
||||
sos_id = sp.piece_to_id("<sos/eos>")
|
||||
sos_id = getattr(LM, "sos_id", 1)
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
context_size = model.decoder.context_size
|
||||
device = next(model.parameters()).device
|
||||
@ -2137,7 +2134,6 @@ def modified_beam_search_lm_shallow_fusion(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
LM: LmScorer,
|
||||
beam: int = 4,
|
||||
return_timestamps: bool = False,
|
||||
@ -2176,7 +2172,7 @@ def modified_beam_search_lm_shallow_fusion(
|
||||
)
|
||||
|
||||
blank_id = model.decoder.blank_id
|
||||
sos_id = sp.piece_to_id("<sos/eos>")
|
||||
sos_id = getattr(LM, "sos_id", 1)
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
context_size = model.decoder.context_size
|
||||
device = next(model.parameters()).device
|
||||
|
@ -675,7 +675,6 @@ def decode_one_batch(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
sp=sp,
|
||||
LM=LM,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
@ -686,7 +685,6 @@ def decode_one_batch(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
sp=sp,
|
||||
LODR_lm=ngram_lm,
|
||||
LODR_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
|
@ -586,7 +586,6 @@ def decode_one_batch(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
sp=sp,
|
||||
LM=LM,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
@ -597,7 +596,6 @@ def decode_one_batch(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
sp=sp,
|
||||
LODR_lm=ngram_lm,
|
||||
LODR_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
|
@ -533,7 +533,6 @@ def decode_one_batch(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
sp=sp,
|
||||
LM=LM,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
@ -544,7 +543,6 @@ def decode_one_batch(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
sp=sp,
|
||||
LODR_lm=ngram_lm,
|
||||
LODR_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
|
@ -40,8 +40,8 @@ from tqdm import tqdm
|
||||
# and 'data()' is only supported in static graph mode. So if you
|
||||
# want to use this api, should call 'paddle.enable_static()' before
|
||||
# this api to enter static graph mode.
|
||||
paddle.enable_static()
|
||||
paddle.disable_signal_handler()
|
||||
# paddle.enable_static()
|
||||
# paddle.disable_signal_handler()
|
||||
jieba.enable_paddle()
|
||||
|
||||
|
||||
|
@ -261,3 +261,107 @@ if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
|
||||
log "Stage 18: Compile LG"
|
||||
python ./local/compile_lg.py --lang-dir $lang_char_dir
|
||||
fi
|
||||
|
||||
# prepare RNNLM data
|
||||
if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then
|
||||
log "Stage 19: Prepare LM training data"
|
||||
|
||||
log "Processing char based data"
|
||||
text_out_dir=data/lm_char
|
||||
|
||||
mkdir -p $text_out_dir
|
||||
|
||||
log "Genearating training text data"
|
||||
|
||||
if [ ! -f $text_out_dir/lm_data.pt ]; then
|
||||
./local/prepare_char_lm_training_data.py \
|
||||
--lang-char data/lang_char \
|
||||
--lm-data $lang_char_dir/text_words_segmentation \
|
||||
--lm-archive $text_out_dir/lm_data.pt
|
||||
fi
|
||||
|
||||
log "Generating DEV text data"
|
||||
# prepare validation text data
|
||||
if [ ! -f $text_out_dir/valid_text_words_segmentation ]; then
|
||||
valid_text=${text_out_dir}/
|
||||
|
||||
gunzip -c data/manifests/wenetspeech_supervisions_DEV.jsonl.gz \
|
||||
| jq '.text' | sed 's/"//g' \
|
||||
| ./local/text2token.py -t "char" > $text_out_dir/valid_text
|
||||
|
||||
python3 ./local/text2segments.py \
|
||||
--num-process $nj \
|
||||
--input-file $text_out_dir/valid_text \
|
||||
--output-file $text_out_dir/valid_text_words_segmentation
|
||||
fi
|
||||
|
||||
./local/prepare_char_lm_training_data.py \
|
||||
--lang-char data/lang_char \
|
||||
--lm-data $text_out_dir/valid_text_words_segmentation \
|
||||
--lm-archive $text_out_dir/lm_data_valid.pt
|
||||
|
||||
# prepare TEST text data
|
||||
if [ ! -f $text_out_dir/TEST_text_words_segmentation ]; then
|
||||
log "Prepare text for test set."
|
||||
for test_set in TEST_MEETING TEST_NET; do
|
||||
gunzip -c data/manifests/wenetspeech_supervisions_${test_set}.jsonl.gz \
|
||||
| jq '.text' | sed 's/"//g' \
|
||||
| ./local/text2token.py -t "char" > $text_out_dir/${test_set}_text
|
||||
|
||||
python3 ./local/text2segments.py \
|
||||
--num-process $nj \
|
||||
--input-file $text_out_dir/${test_set}_text \
|
||||
--output-file $text_out_dir/${test_set}_text_words_segmentation
|
||||
done
|
||||
|
||||
cat $text_out_dir/TEST_*_text_words_segmentation > $text_out_dir/test_text_words_segmentation
|
||||
fi
|
||||
|
||||
./local/prepare_char_lm_training_data.py \
|
||||
--lang-char data/lang_char \
|
||||
--lm-data $text_out_dir/test_text_words_segmentation \
|
||||
--lm-archive $text_out_dir/lm_data_test.pt
|
||||
|
||||
fi
|
||||
|
||||
# sort RNNLM data
|
||||
if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
|
||||
text_out_dir=data/lm_char
|
||||
|
||||
log "Sort lm data"
|
||||
|
||||
./local/sort_lm_training_data.py \
|
||||
--in-lm-data $text_out_dir/lm_data.pt \
|
||||
--out-lm-data $text_out_dir/sorted_lm_data.pt \
|
||||
--out-statistics $text_out_dir/statistics.txt
|
||||
|
||||
./local/sort_lm_training_data.py \
|
||||
--in-lm-data $text_out_dir/lm_data_valid.pt \
|
||||
--out-lm-data $text_out_dir/sorted_lm_data-valid.pt \
|
||||
--out-statistics $text_out_dir/statistics-valid.txt
|
||||
|
||||
./local/sort_lm_training_data.py \
|
||||
--in-lm-data $text_out_dir/lm_data_test.pt \
|
||||
--out-lm-data $text_out_dir/sorted_lm_data-test.pt \
|
||||
--out-statistics $text_out_dir/statistics-test.txt
|
||||
fi
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1"
|
||||
|
||||
if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then
|
||||
log "Stage 21: Train RNN LM model"
|
||||
python ../../../icefall/rnn_lm/train.py \
|
||||
--start-epoch 0 \
|
||||
--world-size 2 \
|
||||
--num-epochs 20 \
|
||||
--use-fp16 0 \
|
||||
--embedding-dim 2048 \
|
||||
--hidden-dim 2048 \
|
||||
--num-layers 2 \
|
||||
--batch-size 400 \
|
||||
--exp-dir rnnlm_char/exp \
|
||||
--lm-data data/lm_char/sorted_lm_data.pt \
|
||||
--lm-data-valid data/lm_char/sorted_lm_data-valid.pt \
|
||||
--vocab-size 4336 \
|
||||
--master-port 12340
|
||||
fi
|
@ -2,6 +2,7 @@
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -91,6 +92,22 @@ When training with the L subset, the streaming usage:
|
||||
--causal-convolution 1 \
|
||||
--decode-chunk-size 16 \
|
||||
--left-context 64
|
||||
|
||||
(4) modified beam search with RNNLM shallow fusion
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 35 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search_lm_shallow_fusion \
|
||||
--beam-size 4 \
|
||||
--lm-type rnn \
|
||||
--lm-scale 0.3 \
|
||||
--lm-exp-dir /path/to/LM \
|
||||
--rnn-lm-epoch 99 \
|
||||
--rnn-lm-avg 1 \
|
||||
--rnn-lm-num-layers 3 \
|
||||
--rnn-lm-tie-weights 1
|
||||
"""
|
||||
|
||||
|
||||
@ -111,9 +128,12 @@ from beam_search import (
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
modified_beam_search_lm_shallow_fusion,
|
||||
modified_beam_search_LODR,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall import LmScorer, NgramLm
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
@ -224,6 +244,16 @@ def get_parser():
|
||||
Used only when --decoding-method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
@ -277,6 +307,50 @@ def get_parser():
|
||||
help="left context can be seen during decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-shallow-fusion",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Use neural network LM for shallow fusion.
|
||||
If you want to use LODR, you will also need to set this to true
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lm-type",
|
||||
type=str,
|
||||
default="rnn",
|
||||
help="Type of NN lm",
|
||||
choices=["rnn", "transformer"],
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lm-scale",
|
||||
type=float,
|
||||
default=0.3,
|
||||
help="""The scale of the neural network LM
|
||||
Used only when `--use-shallow-fusion` is set to True.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens-ngram",
|
||||
type=int,
|
||||
default=3,
|
||||
help="""Token Ngram used for rescoring.
|
||||
Used only when the decoding method is
|
||||
modified_beam_search_ngram_rescoring, or LODR
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--backoff-id",
|
||||
type=int,
|
||||
default=500,
|
||||
help="""ID of the backoff symbol.
|
||||
Used only when the decoding method is
|
||||
modified_beam_search_ngram_rescoring""",
|
||||
)
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -288,6 +362,9 @@ def decode_one_batch(
|
||||
lexicon: Lexicon,
|
||||
batch: dict,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
ngram_lm: Optional[NgramLm] = None,
|
||||
ngram_lm_scale: float = 1.0,
|
||||
LM: Optional[LmScorer] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
@ -374,6 +451,28 @@ def decode_one_batch(
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
|
||||
hyp_tokens = modified_beam_search_lm_shallow_fusion(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
LM=LM,
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
elif params.decoding_method == "modified_beam_search_LODR":
|
||||
hyp_tokens = modified_beam_search_LODR(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
LODR_lm=ngram_lm,
|
||||
LODR_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
@ -419,6 +518,9 @@ def decode_dataset(
|
||||
model: nn.Module,
|
||||
lexicon: Lexicon,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
ngram_lm: Optional[NgramLm] = None,
|
||||
ngram_lm_scale: float = 1.0,
|
||||
LM: Optional[LmScorer] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
@ -432,6 +534,8 @@ def decode_dataset(
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
LM:
|
||||
A neural network LM, used during shallow fusion
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
@ -449,7 +553,7 @@ def decode_dataset(
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 100
|
||||
else:
|
||||
log_interval = 2
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
@ -463,6 +567,9 @@ def decode_dataset(
|
||||
lexicon=lexicon,
|
||||
decoding_graph=decoding_graph,
|
||||
batch=batch,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
@ -524,6 +631,7 @@ def save_results(
|
||||
def main():
|
||||
parser = get_parser()
|
||||
WenetSpeechAsrDataModule.add_arguments(parser)
|
||||
LmScorer.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
@ -535,6 +643,8 @@ def main():
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"modified_beam_search",
|
||||
"modified_beam_search_lm_shallow_fusion",
|
||||
"modified_beam_search_LODR",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
@ -549,6 +659,22 @@ def main():
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if "ngram" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
if params.use_shallow_fusion:
|
||||
if params.lm_type == "rnn":
|
||||
params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}"
|
||||
elif params.lm_type == "transformer":
|
||||
params.suffix += f"-transformer-lm-scale-{params.lm_scale}"
|
||||
|
||||
if "LODR" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
|
||||
)
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
@ -558,6 +684,7 @@ def main():
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
# import pdb; pdb.set_trace()
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
params.blank_id = lexicon.token_table["<blk>"]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
@ -652,6 +779,37 @@ def main():
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
# only load N-gram LM when needed
|
||||
if "ngram" in params.decoding_method or "LODR" in params.decoding_method:
|
||||
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
|
||||
logging.info(f"lm filename: {lm_filename}")
|
||||
ngram_lm = NgramLm(
|
||||
str(params.lang_dir / lm_filename),
|
||||
backoff_id=params.backoff_id,
|
||||
is_binary=False,
|
||||
)
|
||||
logging.info(f"num states: {ngram_lm.lm.num_states}")
|
||||
ngram_lm_scale = params.ngram_lm_scale
|
||||
else:
|
||||
ngram_lm = None
|
||||
ngram_lm_scale = None
|
||||
|
||||
# import pdb; pdb.set_trace()
|
||||
# only load the neural network LM if doing shallow fusion
|
||||
if params.use_shallow_fusion:
|
||||
LM = LmScorer(
|
||||
lm_type=params.lm_type,
|
||||
params=params,
|
||||
device=device,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
LM.to(device)
|
||||
LM.eval()
|
||||
|
||||
num_param = sum([p.numel() for p in LM.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
else:
|
||||
LM = None
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
@ -684,6 +842,9 @@ def main():
|
||||
model=model,
|
||||
lexicon=lexicon,
|
||||
decoding_graph=decoding_graph,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
)
|
||||
save_results(
|
||||
params=params,
|
||||
|
@ -50,7 +50,7 @@ class LmScorer(torch.nn.Module):
|
||||
def add_arguments(cls, parser):
|
||||
# LM general arguments
|
||||
parser.add_argument(
|
||||
"--vocab-size",
|
||||
"--lm-vocab-size",
|
||||
type=int,
|
||||
default=500,
|
||||
)
|
||||
|
@ -33,7 +33,7 @@ import torch
|
||||
from dataset import get_dataloader
|
||||
from model import RnnLmModel
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||
|
||||
|
||||
@ -49,6 +49,7 @@ def get_parser():
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
@ -58,6 +59,16 @@ def get_parser():
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
@ -154,7 +165,14 @@ def main():
|
||||
|
||||
params = AttributeDict(vars(args))
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log-ppl/")
|
||||
if params.iter > 0:
|
||||
setup_logger(
|
||||
f"{params.exp_dir}/log-ppl/log-ppl-iter-{params.iter}-avg-{params.avg}"
|
||||
)
|
||||
else:
|
||||
setup_logger(
|
||||
f"{params.exp_dir}/log-ppl/log-ppl-epoch-{params.epoch}-avg-{params.avg}"
|
||||
)
|
||||
logging.info("Computing perplexity started")
|
||||
logging.info(params)
|
||||
|
||||
@ -173,19 +191,39 @@ def main():
|
||||
tie_weights=params.tie_weights,
|
||||
)
|
||||
|
||||
if params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device), strict=False
|
||||
)
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if start >= 0:
|
||||
if i >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device), strict=False
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
num_param_requires_grad = sum(
|
||||
|
@ -25,7 +25,7 @@ from pathlib import Path
|
||||
import torch
|
||||
from model import RnnLmModel
|
||||
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||
from icefall.utils import AttributeDict, load_averaged_model, str2bool
|
||||
|
||||
|
||||
@ -51,6 +51,16 @@ def get_parser():
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vocab-size",
|
||||
type=int,
|
||||
@ -133,11 +143,36 @@ def main():
|
||||
|
||||
model.to(device)
|
||||
|
||||
if params.avg == 1:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device), strict=False
|
||||
)
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
model = load_averaged_model(
|
||||
params.exp_dir, model, params.epoch, params.avg, device
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints(filenames, device=device), strict=False
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
|
@ -49,6 +49,7 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
@ -178,6 +179,33 @@ def get_parser():
|
||||
help="The seed for random generators intended for reproducibility",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lr",
|
||||
type=float,
|
||||
default=1e-3,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-sent-len",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Maximum number of tokens in a sentence. This is used
|
||||
to adjust batch-size dynamically""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--save-every-n",
|
||||
type=int,
|
||||
default=2000,
|
||||
help="""Save checkpoint after processing this number of batches"
|
||||
periodically. We save checkpoint to exp-dir/ whenever
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
||||
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
||||
end of each epoch where `xxx` is the epoch number counting from 0.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -190,16 +218,15 @@ def get_params() -> AttributeDict:
|
||||
"sos_id": 1,
|
||||
"eos_id": 1,
|
||||
"blank_id": 0,
|
||||
"lr": 1e-3,
|
||||
"weight_decay": 1e-6,
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
"best_valid_epoch": -1,
|
||||
"batch_idx_train": 0,
|
||||
"log_interval": 200,
|
||||
"log_interval": 100,
|
||||
"reset_interval": 2000,
|
||||
"valid_interval": 5000,
|
||||
"valid_interval": 200,
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
@ -382,6 +409,7 @@ def train_one_epoch(
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Train the model for one epoch.
|
||||
|
||||
@ -430,6 +458,19 @@ def train_one_epoch(
|
||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||
optimizer.step()
|
||||
|
||||
if (
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
):
|
||||
save_checkpoint_with_global_batch_idx(
|
||||
out_dir=params.exp_dir,
|
||||
global_batch_idx=params.batch_idx_train,
|
||||
model=model,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
# Note: "frames" here means "num_tokens"
|
||||
this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
|
||||
@ -580,6 +621,7 @@ def run(rank, world_size, args):
|
||||
valid_dl=valid_dl,
|
||||
tb_writer=tb_writer,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
save_checkpoint(
|
||||
|
Loading…
x
Reference in New Issue
Block a user