mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Update decode.py and train.py to use periodically averaged models.
This commit is contained in:
parent
7b786ce0b9
commit
2ce48a2c21
@ -20,40 +20,40 @@
|
|||||||
Usage:
|
Usage:
|
||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless4/decode.py \
|
./pruned_transducer_stateless4/decode.py \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search (not recommended)
|
(2) beam search (not recommended)
|
||||||
./pruned_transducer_stateless4/decode.py \
|
./pruned_transducer_stateless4/decode.py \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method beam_search \
|
--decoding-method beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./pruned_transducer_stateless4/decode.py \
|
./pruned_transducer_stateless4/decode.py \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./pruned_transducer_stateless4/decode.py \
|
./pruned_transducer_stateless4/decode.py \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method fast_beam_search \
|
--decoding-method fast_beam_search \
|
||||||
--beam 4 \
|
--beam 4 \
|
||||||
--max-contexts 4 \
|
--max-contexts 4 \
|
||||||
--max-states 8
|
--max-states 8
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -502,7 +502,7 @@ def main():
|
|||||||
sp = spm.SentencePieceProcessor()
|
sp = spm.SentencePieceProcessor()
|
||||||
sp.load(params.bpe_model)
|
sp.load(params.bpe_model)
|
||||||
|
|
||||||
# <blk> and <unk> is defined in local/train_bpe_model.py
|
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
#
|
#
|
||||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||||
|
# Zengwei Yao)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -78,6 +79,7 @@ from train import add_model_arguments, get_params, get_transducer_model
|
|||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
@ -85,6 +87,7 @@ from icefall.utils import (
|
|||||||
AttributeDict,
|
AttributeDict,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts,
|
store_transcripts,
|
||||||
|
str2bool,
|
||||||
write_error_stats,
|
write_error_stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -97,9 +100,9 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--epoch",
|
"--epoch",
|
||||||
type=int,
|
type=int,
|
||||||
default=28,
|
default=30,
|
||||||
help="""It specifies the checkpoint to use for decoding.
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
Note: Epoch counts from 0.
|
Note: Epoch counts from 1.
|
||||||
You can specify --avg to use more checkpoints for model averaging.""",
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -122,6 +125,17 @@ def get_parser():
|
|||||||
"'--epoch' and '--iter'",
|
"'--epoch' and '--iter'",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
@ -238,7 +252,7 @@ def decode_one_batch(
|
|||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
the returned dict.
|
the returned dict.
|
||||||
"""
|
"""
|
||||||
device = model.device
|
device = next(model.parameters()).device
|
||||||
feature = batch["inputs"]
|
feature = batch["inputs"]
|
||||||
assert feature.ndim == 3
|
assert feature.ndim == 3
|
||||||
|
|
||||||
@ -475,6 +489,9 @@ def main():
|
|||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
|
|
||||||
|
if params.use_averaged_model:
|
||||||
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
|
|
||||||
@ -497,38 +514,85 @@ def main():
|
|||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
model = get_transducer_model(params)
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
if params.iter > 0:
|
if not params.use_averaged_model:
|
||||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
if params.iter > 0:
|
||||||
: params.avg
|
filenames = find_checkpoints(
|
||||||
]
|
params.exp_dir, iteration=-params.iter
|
||||||
if len(filenames) == 0:
|
)[: params.avg]
|
||||||
raise ValueError(
|
if len(filenames) == 0:
|
||||||
f"No checkpoints found for"
|
raise ValueError(
|
||||||
f" --iter {params.iter}, --avg {params.avg}"
|
f"No checkpoints found for"
|
||||||
)
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
elif len(filenames) < params.avg:
|
)
|
||||||
raise ValueError(
|
elif len(filenames) < params.avg:
|
||||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
raise ValueError(
|
||||||
f" --iter {params.iter}, --avg {params.avg}"
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
)
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
logging.info(f"averaging {filenames}")
|
)
|
||||||
model.to(device)
|
logging.info(f"averaging {filenames}")
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.to(device)
|
||||||
elif params.avg == 1:
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
else:
|
else:
|
||||||
start = params.epoch - params.avg + 1
|
if params.iter > 0:
|
||||||
filenames = []
|
filenames = find_checkpoints(
|
||||||
for i in range(start, params.epoch + 1):
|
params.exp_dir, iteration=-params.iter
|
||||||
if start >= 0:
|
)[: params.avg + 1]
|
||||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
if len(filenames) == 0:
|
||||||
logging.info(f"averaging {filenames}")
|
raise ValueError(
|
||||||
model.to(device)
|
f"No checkpoints found for"
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
model.device = device
|
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if params.decoding_method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
# Wei Kang
|
# Wei Kang,
|
||||||
# Mingshuang Luo)
|
# Mingshuang Luo,)
|
||||||
|
# Zengwei Yao)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -24,7 +25,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|||||||
./pruned_transducer_stateless5/train.py \
|
./pruned_transducer_stateless5/train.py \
|
||||||
--world-size 4 \
|
--world-size 4 \
|
||||||
--num-epochs 30 \
|
--num-epochs 30 \
|
||||||
--start-epoch 0 \
|
--start-epoch 1 \
|
||||||
--exp-dir pruned_transducer_stateless5/exp \
|
--exp-dir pruned_transducer_stateless5/exp \
|
||||||
--full-libri 1 \
|
--full-libri 1 \
|
||||||
--max-duration 300
|
--max-duration 300
|
||||||
@ -34,7 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|||||||
./pruned_transducer_stateless5/train.py \
|
./pruned_transducer_stateless5/train.py \
|
||||||
--world-size 4 \
|
--world-size 4 \
|
||||||
--num-epochs 30 \
|
--num-epochs 30 \
|
||||||
--start-epoch 0 \
|
--start-epoch 1 \
|
||||||
--use-fp16 1 \
|
--use-fp16 1 \
|
||||||
--exp-dir pruned_transducer_stateless5/exp \
|
--exp-dir pruned_transducer_stateless5/exp \
|
||||||
--full-libri 1 \
|
--full-libri 1 \
|
||||||
@ -44,6 +45,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -73,7 +75,10 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from icefall import diagnostics
|
from icefall import diagnostics
|
||||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
from icefall.checkpoint import (
|
||||||
|
save_checkpoint_with_global_batch_idx,
|
||||||
|
update_averaged_model,
|
||||||
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
@ -166,10 +171,10 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--start-epoch",
|
"--start-epoch",
|
||||||
type=int,
|
type=int,
|
||||||
default=0,
|
default=1,
|
||||||
help="""Resume training from from this epoch.
|
help="""Resume training from this epoch. It should be positive.
|
||||||
If it is positive, it will load checkpoint from
|
If larger than 1, it will load checkpoint from
|
||||||
transducer_stateless2/exp/epoch-{start_epoch-1}.pt
|
exp-dir/epoch-{start_epoch-1}.pt
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -282,7 +287,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save-every-n",
|
"--save-every-n",
|
||||||
type=int,
|
type=int,
|
||||||
default=8000,
|
default=4000,
|
||||||
help="""Save checkpoint after processing this number of batches"
|
help="""Save checkpoint after processing this number of batches"
|
||||||
periodically. We save checkpoint to exp-dir/ whenever
|
periodically. We save checkpoint to exp-dir/ whenever
|
||||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||||
@ -295,7 +300,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--keep-last-k",
|
"--keep-last-k",
|
||||||
type=int,
|
type=int,
|
||||||
default=20,
|
default=30,
|
||||||
help="""Only keep this number of checkpoints on disk.
|
help="""Only keep this number of checkpoints on disk.
|
||||||
For instance, if it is 3, there are only 3 checkpoints
|
For instance, if it is 3, there are only 3 checkpoints
|
||||||
in the exp-dir with filenames `checkpoint-xxx.pt`.
|
in the exp-dir with filenames `checkpoint-xxx.pt`.
|
||||||
@ -303,6 +308,19 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--average-period",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="""Update the averaged model, namely `model_avg`, after processing
|
||||||
|
this number of batches. `model_avg` is a separate version of model,
|
||||||
|
in which each floating-point parameter is the average of all the
|
||||||
|
parameters from the start of training. Each time we take the average,
|
||||||
|
we do: `model_avg = model * (average_period / batch_idx_train) +
|
||||||
|
model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-fp16",
|
"--use-fp16",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -434,6 +452,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|||||||
def load_checkpoint_if_available(
|
def load_checkpoint_if_available(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
|
model_avg: nn.Module = None,
|
||||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
scheduler: Optional[LRSchedulerType] = None,
|
scheduler: Optional[LRSchedulerType] = None,
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
@ -441,7 +460,7 @@ def load_checkpoint_if_available(
|
|||||||
|
|
||||||
If params.start_batch is positive, it will load the checkpoint from
|
If params.start_batch is positive, it will load the checkpoint from
|
||||||
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
|
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
|
||||||
params.start_epoch is positive, it will load the checkpoint from
|
params.start_epoch is larger than 1, it will load the checkpoint from
|
||||||
`params.start_epoch - 1`.
|
`params.start_epoch - 1`.
|
||||||
|
|
||||||
Apart from loading state dict for `model` and `optimizer` it also updates
|
Apart from loading state dict for `model` and `optimizer` it also updates
|
||||||
@ -453,6 +472,8 @@ def load_checkpoint_if_available(
|
|||||||
The return value of :func:`get_params`.
|
The return value of :func:`get_params`.
|
||||||
model:
|
model:
|
||||||
The training model.
|
The training model.
|
||||||
|
model_avg:
|
||||||
|
The stored model averaged from the start of training.
|
||||||
optimizer:
|
optimizer:
|
||||||
The optimizer that we are using.
|
The optimizer that we are using.
|
||||||
scheduler:
|
scheduler:
|
||||||
@ -462,7 +483,7 @@ def load_checkpoint_if_available(
|
|||||||
"""
|
"""
|
||||||
if params.start_batch > 0:
|
if params.start_batch > 0:
|
||||||
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
|
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
|
||||||
elif params.start_epoch > 0:
|
elif params.start_epoch > 1:
|
||||||
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
@ -472,6 +493,7 @@ def load_checkpoint_if_available(
|
|||||||
saved_params = load_checkpoint(
|
saved_params = load_checkpoint(
|
||||||
filename,
|
filename,
|
||||||
model=model,
|
model=model,
|
||||||
|
model_avg=model_avg,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
)
|
)
|
||||||
@ -498,7 +520,8 @@ def load_checkpoint_if_available(
|
|||||||
|
|
||||||
def save_checkpoint(
|
def save_checkpoint(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: Union[nn.Module, DDP],
|
||||||
|
model_avg: Optional[nn.Module] = None,
|
||||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
scheduler: Optional[LRSchedulerType] = None,
|
scheduler: Optional[LRSchedulerType] = None,
|
||||||
sampler: Optional[CutSampler] = None,
|
sampler: Optional[CutSampler] = None,
|
||||||
@ -512,6 +535,8 @@ def save_checkpoint(
|
|||||||
It is returned by :func:`get_params`.
|
It is returned by :func:`get_params`.
|
||||||
model:
|
model:
|
||||||
The training model.
|
The training model.
|
||||||
|
model_avg:
|
||||||
|
The stored model averaged from the start of training.
|
||||||
optimizer:
|
optimizer:
|
||||||
The optimizer used in the training.
|
The optimizer used in the training.
|
||||||
sampler:
|
sampler:
|
||||||
@ -525,6 +550,7 @@ def save_checkpoint(
|
|||||||
save_checkpoint_impl(
|
save_checkpoint_impl(
|
||||||
filename=filename,
|
filename=filename,
|
||||||
model=model,
|
model=model,
|
||||||
|
model_avg=model_avg,
|
||||||
params=params,
|
params=params,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
@ -544,7 +570,7 @@ def save_checkpoint(
|
|||||||
|
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: Union[nn.Module, DDP],
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
is_training: bool,
|
is_training: bool,
|
||||||
@ -568,7 +594,11 @@ def compute_loss(
|
|||||||
warmup: a floating point value which increases throughout training;
|
warmup: a floating point value which increases throughout training;
|
||||||
values >= 1.0 are fully warmed up and have all modules present.
|
values >= 1.0 are fully warmed up and have all modules present.
|
||||||
"""
|
"""
|
||||||
device = model.device
|
device = (
|
||||||
|
model.device
|
||||||
|
if isinstance(model, DDP)
|
||||||
|
else next(model.parameters()).device
|
||||||
|
)
|
||||||
feature = batch["inputs"]
|
feature = batch["inputs"]
|
||||||
# at entry, feature is (N, T, C)
|
# at entry, feature is (N, T, C)
|
||||||
assert feature.ndim == 3
|
assert feature.ndim == 3
|
||||||
@ -624,7 +654,7 @@ def compute_loss(
|
|||||||
|
|
||||||
def compute_validation_loss(
|
def compute_validation_loss(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: Union[nn.Module, DDP],
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
valid_dl: torch.utils.data.DataLoader,
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
@ -658,13 +688,14 @@ def compute_validation_loss(
|
|||||||
|
|
||||||
def train_one_epoch(
|
def train_one_epoch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: Union[nn.Module, DDP],
|
||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
scheduler: LRSchedulerType,
|
scheduler: LRSchedulerType,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
train_dl: torch.utils.data.DataLoader,
|
train_dl: torch.utils.data.DataLoader,
|
||||||
valid_dl: torch.utils.data.DataLoader,
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
scaler: GradScaler,
|
scaler: GradScaler,
|
||||||
|
model_avg: Optional[nn.Module] = None,
|
||||||
tb_writer: Optional[SummaryWriter] = None,
|
tb_writer: Optional[SummaryWriter] = None,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
rank: int = 0,
|
rank: int = 0,
|
||||||
@ -690,6 +721,8 @@ def train_one_epoch(
|
|||||||
Dataloader for the validation dataset.
|
Dataloader for the validation dataset.
|
||||||
scaler:
|
scaler:
|
||||||
The scaler used for mix precision training.
|
The scaler used for mix precision training.
|
||||||
|
model_avg:
|
||||||
|
The stored model averaged from the start of training.
|
||||||
tb_writer:
|
tb_writer:
|
||||||
Writer to write log messages to tensorboard.
|
Writer to write log messages to tensorboard.
|
||||||
world_size:
|
world_size:
|
||||||
@ -739,6 +772,17 @@ def train_one_epoch(
|
|||||||
if params.print_diagnostics and batch_idx == 5:
|
if params.print_diagnostics and batch_idx == 5:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if (
|
||||||
|
rank == 0
|
||||||
|
and params.batch_idx_train > 0
|
||||||
|
and params.batch_idx_train % params.average_period == 0
|
||||||
|
):
|
||||||
|
update_averaged_model(
|
||||||
|
params=params,
|
||||||
|
model_cur=model,
|
||||||
|
model_avg=model_avg,
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
params.batch_idx_train > 0
|
params.batch_idx_train > 0
|
||||||
and params.batch_idx_train % params.save_every_n == 0
|
and params.batch_idx_train % params.save_every_n == 0
|
||||||
@ -748,6 +792,7 @@ def train_one_epoch(
|
|||||||
out_dir=params.exp_dir,
|
out_dir=params.exp_dir,
|
||||||
global_batch_idx=params.batch_idx_train,
|
global_batch_idx=params.batch_idx_train,
|
||||||
model=model,
|
model=model,
|
||||||
|
model_avg=model_avg,
|
||||||
params=params,
|
params=params,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
@ -855,13 +900,21 @@ def run(rank, world_size, args):
|
|||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
assert params.save_every_n >= params.average_period
|
||||||
|
model_avg: Optional[nn.Module] = None
|
||||||
|
if rank == 0:
|
||||||
|
# model_avg is only used with rank 0
|
||||||
|
model_avg = copy.deepcopy(model)
|
||||||
|
|
||||||
|
assert params.start_epoch > 0, params.start_epoch
|
||||||
|
checkpoints = load_checkpoint_if_available(
|
||||||
|
params=params, model=model, model_avg=model_avg
|
||||||
|
)
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
logging.info("Using DDP")
|
logging.info("Using DDP")
|
||||||
model = DDP(model, device_ids=[rank])
|
model = DDP(model, device_ids=[rank])
|
||||||
model.device = device
|
|
||||||
|
|
||||||
optimizer = Eve(model.parameters(), lr=params.initial_lr)
|
optimizer = Eve(model.parameters(), lr=params.initial_lr)
|
||||||
|
|
||||||
@ -934,10 +987,10 @@ def run(rank, world_size, args):
|
|||||||
logging.info("Loading grad scaler state dict")
|
logging.info("Loading grad scaler state dict")
|
||||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs + 1):
|
||||||
scheduler.step_epoch(epoch)
|
scheduler.step_epoch(epoch - 1)
|
||||||
fix_random_seed(params.seed + epoch)
|
fix_random_seed(params.seed + epoch - 1)
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch - 1)
|
||||||
|
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||||
@ -947,6 +1000,7 @@ def run(rank, world_size, args):
|
|||||||
train_one_epoch(
|
train_one_epoch(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
|
model_avg=model_avg,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
@ -965,6 +1019,7 @@ def run(rank, world_size, args):
|
|||||||
save_checkpoint(
|
save_checkpoint(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
|
model_avg=model_avg,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
sampler=train_dl.sampler,
|
sampler=train_dl.sampler,
|
||||||
@ -1012,7 +1067,7 @@ def display_and_save_batch(
|
|||||||
|
|
||||||
|
|
||||||
def scan_pessimistic_batches_for_oom(
|
def scan_pessimistic_batches_for_oom(
|
||||||
model: nn.Module,
|
model: Union[nn.Module, DDP],
|
||||||
train_dl: torch.utils.data.DataLoader,
|
train_dl: torch.utils.data.DataLoader,
|
||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
@ -1021,7 +1076,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
from lhotse.dataset import find_pessimistic_batches
|
from lhotse.dataset import find_pessimistic_batches
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
|
"Sanity check -- see if any of the batches in epoch 1 would cause OOM."
|
||||||
)
|
)
|
||||||
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
|
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
|
||||||
for criterion, cuts in batches.items():
|
for criterion, cuts in batches.items():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user