mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Add scripts for training and perplexity computation.
This commit is contained in:
parent
42dcd53361
commit
2213154c69
@ -72,6 +72,16 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
--bpe-model $out_dir/bpe.model \
|
||||
--lm-data $dl_dir/ptb.train.txt \
|
||||
--lm-archive $out_dir/lm_data.pt
|
||||
|
||||
./local/prepare_lm_training_data.py \
|
||||
--bpe-model $out_dir/bpe.model \
|
||||
--lm-data $dl_dir/ptb.valid.txt \
|
||||
--lm-archive $out_dir/lm_data-valid.pt
|
||||
|
||||
./local/prepare_lm_training_data.py \
|
||||
--bpe-model $out_dir/bpe.model \
|
||||
--lm-data $dl_dir/ptb.test.txt \
|
||||
--lm-archive $out_dir/lm_data-test.pt
|
||||
done
|
||||
fi
|
||||
|
||||
@ -91,5 +101,15 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
--in-lm-data $out_dir/lm_data.pt \
|
||||
--out-lm-data $out_dir/sorted_lm_data.pt \
|
||||
--out-statistics $out_dir/statistics.txt
|
||||
|
||||
./local/sort_lm_training_data.py \
|
||||
--in-lm-data $out_dir/lm_data-valid.pt \
|
||||
--out-lm-data $out_dir/sorted_lm_data-valid.pt \
|
||||
--out-statistics $out_dir/statistics-valid.txt
|
||||
|
||||
./local/sort_lm_training_data.py \
|
||||
--in-lm-data $out_dir/lm_data-test.pt \
|
||||
--out-lm-data $out_dir/sorted_lm_data-test.pt \
|
||||
--out-statistics $out_dir/statistics-test.txt
|
||||
done
|
||||
fi
|
||||
|
228
egs/ptb/LM/rnn_lm/compute_perplexity.py
Executable file
228
egs/ptb/LM/rnn_lm/compute_perplexity.py
Executable file
@ -0,0 +1,228 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Usage:
|
||||
./rnn_lm/compute_perplexity.py \
|
||||
--epoch 4 \
|
||||
--avg 2 \
|
||||
--lm-data ./data/bpe_500/sorted_lm_data-test.pt
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from rnn_lm.dataset import get_dataloader
|
||||
from rnn_lm.model import RnnLmModel
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.utils import AttributeDict, setup_logger
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=49,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="rnn_lm/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lm-data",
|
||||
type=str,
|
||||
help="Path to the LM test data for computing perplexity",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vocab-size",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Vocabulary size of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--embedding-dim",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Embedding dim of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hidden-dim",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Hidden dim of the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-layers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of RNN layers the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of RNN layers the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-sent-len",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of RNN layers the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sos-id",
|
||||
type=int,
|
||||
default=1,
|
||||
help="SOS ID",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--eos-id",
|
||||
type=int,
|
||||
default=1,
|
||||
help="EOS ID",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--blank-id",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Blank ID",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lm_data = Path(args.lm_data)
|
||||
|
||||
params = AttributeDict(vars(args))
|
||||
print(params)
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log-ppl/")
|
||||
logging.info("Computing perplexity started")
|
||||
logging.info(params)
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
logging.info("About to create model")
|
||||
model = RnnLmModel(
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.embedding_dim,
|
||||
hidden_dim=params.hidden_dim,
|
||||
num_layers=params.num_layers,
|
||||
)
|
||||
|
||||
if params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
model.to(device)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if start >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
|
||||
model.eval()
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
num_param_requires_grad = sum(
|
||||
[p.numel() for p in model.parameters() if p.requires_grad]
|
||||
)
|
||||
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
logging.info(
|
||||
f"Number of model parameters (requires_grad): "
|
||||
f"{num_param_requires_grad} "
|
||||
f"({num_param_requires_grad/num_param_requires_grad*100}%)"
|
||||
)
|
||||
|
||||
logging.info(f"Loading LM test data from {params.lm_data}")
|
||||
test_dl = get_dataloader(
|
||||
filename=params.lm_data,
|
||||
is_distributed=False,
|
||||
params=params,
|
||||
)
|
||||
|
||||
tot_loss = 0.0
|
||||
num_tokens = 0
|
||||
num_sentences = 0
|
||||
for batch_idx, batch in enumerate(test_dl):
|
||||
x, y, sentence_lengths = batch
|
||||
x = x.to(device)
|
||||
y = y.to(device)
|
||||
sentence_lengths = sentence_lengths.to(device)
|
||||
|
||||
nll = model(x, y, sentence_lengths)
|
||||
loss = nll.sum().cpu().item()
|
||||
|
||||
tot_loss += loss
|
||||
num_tokens += sentence_lengths.sum().cpu().item()
|
||||
num_sentences += x.size(0)
|
||||
|
||||
ppl = math.exp(tot_loss / num_tokens)
|
||||
logging.info(
|
||||
f"total nll: {tot_loss}, num tokens: {num_tokens}, "
|
||||
f"num sentences: {num_sentences}, ppl: {ppl:.3f}"
|
||||
)
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -19,6 +19,8 @@ from typing import List, Tuple
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from icefall.utils import AttributeDict
|
||||
|
||||
|
||||
class LmDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
@ -233,7 +235,7 @@ class LmDatasetCollate:
|
||||
for a sentence starting with `self.sos_id`. It is padded to
|
||||
the max sentence length with `self.blank_id`.
|
||||
|
||||
- x, a 2-D tensor of dtype torch.int32; each row contains tokens
|
||||
- y, a 2-D tensor of dtype torch.int32; each row contains tokens
|
||||
for a sentence ending with `self.eos_id` before padding.
|
||||
Then it is padded to the max sentence length with
|
||||
`self.blank_id`.
|
||||
@ -257,4 +259,58 @@ class LmDatasetCollate:
|
||||
)
|
||||
sentence_token_lengths += 1 # plus 1 since we added a SOS
|
||||
|
||||
return x, y, sentence_token_lengths
|
||||
return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths
|
||||
|
||||
|
||||
def get_dataloader(
|
||||
filename: str,
|
||||
is_distributed: bool,
|
||||
params: AttributeDict,
|
||||
) -> torch.utils.data.DataLoader:
|
||||
"""Get dataloader for LM training.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Path to the file containing LM data. The file is assumed to
|
||||
be generated by `../local/sort_lm_training_data.py`.
|
||||
is_distributed:
|
||||
True if using DDP training. False otherwise.
|
||||
params:
|
||||
Set `get_params()` from `rnn_lm/train.py`
|
||||
Returns:
|
||||
Return a dataloader containing the LM data.
|
||||
"""
|
||||
lm_data = torch.load(filename)
|
||||
|
||||
words = lm_data["words"]
|
||||
sentences = lm_data["sentences"]
|
||||
sentence_lengths = lm_data["sentence_lengths"]
|
||||
|
||||
dataset = LmDataset(
|
||||
sentences=sentences,
|
||||
words=words,
|
||||
sentence_lengths=sentence_lengths,
|
||||
max_sent_len=params.max_sent_len,
|
||||
batch_size=params.batch_size,
|
||||
)
|
||||
if is_distributed:
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
dataset, shuffle=True, drop_last=False
|
||||
)
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
collate_fn = LmDatasetCollate(
|
||||
sos_id=params.sos_id,
|
||||
eos_id=params.eos_id,
|
||||
blank_id=params.blank_id,
|
||||
)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=1,
|
||||
collate_fn=collate_fn,
|
||||
sampler=sampler,
|
||||
shuffle=sampler is None,
|
||||
)
|
||||
return dataloader
|
||||
|
145
egs/ptb/LM/rnn_lm/model.py
Normal file
145
egs/ptb/LM/rnn_lm/model.py
Normal file
@ -0,0 +1,145 @@
|
||||
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
lengths:
|
||||
A 1-D tensor containing sentence lengths.
|
||||
Returns:
|
||||
Return a 2-D bool tensor, where masked positions
|
||||
are filled with `True` and non-masked positions are
|
||||
filled with `False`.
|
||||
|
||||
>>> lengths = torch.tensor([1, 3, 2, 5])
|
||||
>>> make_pad_mask(lengths)
|
||||
tensor([[False, True, True, True, True],
|
||||
[False, False, False, True, True],
|
||||
[False, False, True, True, True],
|
||||
[False, False, False, False, False]])
|
||||
"""
|
||||
assert lengths.ndim == 1, lengths.ndim
|
||||
|
||||
max_len = lengths.max()
|
||||
n = lengths.size(0)
|
||||
|
||||
expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)
|
||||
|
||||
return expaned_lengths >= lengths.unsqueeze(1)
|
||||
|
||||
|
||||
class RnnLmModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
embedding_dim: int,
|
||||
hidden_dim: int,
|
||||
num_layers: int,
|
||||
tie_weights: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
vocab_size:
|
||||
Vocabulary size of BPE model.
|
||||
embedding_dim:
|
||||
Input embedding dimension.
|
||||
hidden_dim:
|
||||
Hidden dimension of RNN layers.
|
||||
num_layers:
|
||||
Number of RNN layers.
|
||||
tie_weights:
|
||||
True to share the weights between the input embedding layer and the
|
||||
last output linear layer. See https://arxiv.org/abs/1608.05859
|
||||
and https://arxiv.org/abs/1611.01462
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.input_embedding = torch.nn.Embedding(
|
||||
num_embeddings=vocab_size,
|
||||
embedding_dim=embedding_dim,
|
||||
)
|
||||
|
||||
self.rnn = torch.nn.LSTM(
|
||||
input_size=embedding_dim,
|
||||
hidden_size=hidden_dim,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
)
|
||||
|
||||
self.output_linear = torch.nn.Linear(
|
||||
in_features=hidden_dim, out_features=vocab_size
|
||||
)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
if tie_weights:
|
||||
logging.info("Tying weights")
|
||||
assert embedding_dim == hidden_dim, (embedding_dim, hidden_dim)
|
||||
self.output_linear.weight = self.input_embedding.weight
|
||||
else:
|
||||
logging.info("Not tying weights")
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 2-D tensor with shape (N, L). Each row
|
||||
contains token IDs for a sentence and starts with the SOS token.
|
||||
y:
|
||||
A shifted version of `x` and with EOS appended.
|
||||
lengths:
|
||||
A 1-D tensor of shape (N,). It contains the sentence lengths
|
||||
before padding.
|
||||
Returns:
|
||||
Return a 2-D tensor of shape (N, L) containing negative log-likelihood
|
||||
loss values. Note: Loss values for padding positions are set to 0.
|
||||
"""
|
||||
assert x.ndim == y.ndim == 2, (x.ndim, y.ndim)
|
||||
assert lengths.ndim == 1, lengths.ndim
|
||||
assert x.shape == y.shape, (x.shape, y.shape)
|
||||
|
||||
batch_size = x.size(0)
|
||||
assert lengths.size(0) == batch_size, (lengths.size(0), batch_size)
|
||||
|
||||
# embedding is of shape (N, L, embedding_dim)
|
||||
embedding = self.input_embedding(x)
|
||||
|
||||
# Note: We use batch_first==True
|
||||
rnn_out, _ = self.rnn(embedding)
|
||||
logits = self.output_linear(rnn_out)
|
||||
|
||||
# Note: No need to use `log_softmax()` here
|
||||
# since F.cross_entropy() expects unnormalized probabilities
|
||||
|
||||
# nll_loss is of shape (N*L,)
|
||||
# nll -> negative log-likelihood
|
||||
nll_loss = F.cross_entropy(
|
||||
logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none"
|
||||
)
|
||||
# Set loss values for padding positions to 0
|
||||
mask = make_pad_mask(lengths).reshape(-1)
|
||||
nll_loss.masked_fill_(mask, 0)
|
||||
|
||||
nll_loss = nll_loss.reshape(batch_size, -1)
|
||||
|
||||
return nll_loss
|
84
egs/ptb/LM/rnn_lm/test_model.py
Executable file
84
egs/ptb/LM/rnn_lm/test_model.py
Executable file
@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from rnn_lm.model import RnnLmModel, make_pad_mask
|
||||
|
||||
|
||||
def test_makd_pad_mask():
|
||||
lengths = torch.tensor([1, 3, 2])
|
||||
mask = make_pad_mask(lengths)
|
||||
expected = torch.tensor(
|
||||
[
|
||||
[False, True, True],
|
||||
[False, False, False],
|
||||
[False, False, True],
|
||||
]
|
||||
)
|
||||
assert torch.all(torch.eq(mask, expected))
|
||||
assert (~expected).sum() == lengths.sum()
|
||||
|
||||
|
||||
def test_rnn_lm_model():
|
||||
vocab_size = 4
|
||||
model = RnnLmModel(
|
||||
vocab_size=vocab_size, embedding_dim=10, hidden_dim=10, num_layers=2
|
||||
)
|
||||
x = torch.tensor(
|
||||
[
|
||||
[1, 3, 2, 2],
|
||||
[1, 2, 2, 0],
|
||||
[1, 2, 0, 0],
|
||||
]
|
||||
)
|
||||
y = torch.tensor(
|
||||
[
|
||||
[3, 2, 2, 1],
|
||||
[2, 2, 1, 0],
|
||||
[2, 1, 0, 0],
|
||||
]
|
||||
)
|
||||
lengths = torch.tensor([4, 3, 2])
|
||||
nll_loss = model(x, y, lengths)
|
||||
print(nll_loss)
|
||||
"""
|
||||
tensor([[1.1180, 1.3059, 1.2426, 1.7773],
|
||||
[1.4231, 1.2783, 1.7321, 0.0000],
|
||||
[1.4231, 1.6752, 0.0000, 0.0000]], grad_fn=<ViewBackward>)
|
||||
"""
|
||||
|
||||
|
||||
def test_rnn_lm_model_tie_weights():
|
||||
model = RnnLmModel(
|
||||
vocab_size=10,
|
||||
embedding_dim=10,
|
||||
hidden_dim=10,
|
||||
num_layers=2,
|
||||
tie_weights=True,
|
||||
)
|
||||
assert model.input_embedding.weight is model.output_linear.weight
|
||||
|
||||
|
||||
def main():
|
||||
test_makd_pad_mask()
|
||||
test_rnn_lm_model()
|
||||
test_rnn_lm_model_tie_weights()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(20211122)
|
||||
main()
|
554
egs/ptb/LM/rnn_lm/train.py
Executable file
554
egs/ptb/LM/rnn_lm/train.py
Executable file
@ -0,0 +1,554 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Usage:
|
||||
./rnn_lm/train.py \
|
||||
--world-size 2 \
|
||||
--start-epoch 4
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from lhotse.utils import fix_random_seed
|
||||
from rnn_lm.dataset import get_dataloader
|
||||
from rnn_lm.model import RnnLmModel
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
get_env_info,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of GPUs for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--master-port",
|
||||
type=int,
|
||||
default=12354,
|
||||
help="Master port to use for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tensorboard",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Should various information be logged in tensorboard.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-epochs",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of epochs to train.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--start-epoch",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""Resume training from from this epoch.
|
||||
If it is positive, it will load checkpoint from
|
||||
exp_dir/epoch-{start_epoch-1}.pt
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="rnn_lm/exp",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, logs, etc, are saved
|
||||
""",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
"""Return a dict containing training parameters."""
|
||||
|
||||
params = AttributeDict(
|
||||
{
|
||||
# LM training/validation data
|
||||
"lm_data": "data/bpe_500/sorted_lm_data.pt",
|
||||
"lm_data_valid": "data/bpe_500/sorted_lm_data-valid.pt",
|
||||
"batch_size": 50,
|
||||
"max_sent_len": 200,
|
||||
"sos_id": 1,
|
||||
"eos_id": 1,
|
||||
"blank_id": 0,
|
||||
# model related
|
||||
#
|
||||
# vocab size of the BPE model
|
||||
"vocab_size": 500,
|
||||
"embedding_dim": 2048,
|
||||
"hidden_dim": 2048,
|
||||
"num_layers": 4,
|
||||
#
|
||||
"lr": 1e-3,
|
||||
"weight_decay": 1e-6,
|
||||
#
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
"best_valid_epoch": -1,
|
||||
"batch_idx_train": 0,
|
||||
"log_interval": 50,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 300,
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def load_checkpoint_if_available(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||
) -> None:
|
||||
"""Load checkpoint from file.
|
||||
|
||||
If params.start_epoch is positive, it will load the checkpoint from
|
||||
`params.start_epoch - 1`. Otherwise, this function does nothing.
|
||||
|
||||
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
|
||||
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
|
||||
and `best_valid_loss` in `params`.
|
||||
|
||||
Args:
|
||||
params:
|
||||
The return value of :func:`get_params`.
|
||||
model:
|
||||
The training model.
|
||||
optimizer:
|
||||
The optimizer that we are using.
|
||||
scheduler:
|
||||
The learning rate scheduler we are using.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
if params.start_epoch <= 0:
|
||||
return
|
||||
|
||||
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||
logging.info(f"Loading checkpoint: {filename}")
|
||||
saved_params = load_checkpoint(
|
||||
filename,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
keys = [
|
||||
"best_train_epoch",
|
||||
"best_valid_epoch",
|
||||
"batch_idx_train",
|
||||
"best_train_loss",
|
||||
"best_valid_loss",
|
||||
]
|
||||
for k in keys:
|
||||
params[k] = saved_params[k]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The training model.
|
||||
"""
|
||||
if rank != 0:
|
||||
return
|
||||
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
||||
save_checkpoint_impl(
|
||||
filename=filename,
|
||||
model=model,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if params.best_train_epoch == params.cur_epoch:
|
||||
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||
copyfile(src=filename, dst=best_train_filename)
|
||||
|
||||
if params.best_valid_epoch == params.cur_epoch:
|
||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
|
||||
|
||||
def compute_loss(
|
||||
model: nn.Module,
|
||||
x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
sentence_lengths: torch.Tensor,
|
||||
is_training: bool,
|
||||
) -> Tuple[torch.Tensor, MetricsTracker]:
|
||||
"""Compute the negative log-likelihood loss given a model and its input.
|
||||
Args:
|
||||
model:
|
||||
The NN model, e.g., RnnLmModel.
|
||||
x:
|
||||
A 2-D tensor. Each row contains BPE token IDs for a sentence. Also,
|
||||
each row starts with SOS ID.
|
||||
y:
|
||||
A 2-D tensor. Each row is a shifted version of the corresponding row
|
||||
in `x` but ends with an EOS ID (before padding).
|
||||
sentence_lengths:
|
||||
A 1-D tensor containing number of tokens of each sentence
|
||||
before padding.
|
||||
is_training:
|
||||
True for training. False for validation.
|
||||
"""
|
||||
with torch.set_grad_enabled(is_training):
|
||||
device = model.device
|
||||
x = x.to(device)
|
||||
y = y.to(device)
|
||||
sentence_lengths = sentence_lengths.to(device)
|
||||
|
||||
nll = model(x, y, sentence_lengths)
|
||||
loss = nll.sum()
|
||||
|
||||
num_tokens = sentence_lengths.sum().item()
|
||||
|
||||
loss_info = MetricsTracker()
|
||||
# Note: Due to how MetricsTracker() is designed,
|
||||
# we use "frames" instead of "num_tokens" as a key here
|
||||
loss_info["frames"] = num_tokens
|
||||
loss_info["loss"] = loss.detach().item()
|
||||
return loss, loss_info
|
||||
|
||||
|
||||
def compute_validation_loss(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
world_size: int = 1,
|
||||
) -> MetricsTracker:
|
||||
"""Run the validation process. The validation loss
|
||||
is saved in `params.valid_loss`.
|
||||
"""
|
||||
model.eval()
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
x, y, sentence_lengths = batch
|
||||
|
||||
loss, loss_info = compute_loss(
|
||||
model=model,
|
||||
x=x,
|
||||
y=y,
|
||||
sentence_lengths=sentence_lengths,
|
||||
is_training=False,
|
||||
)
|
||||
|
||||
assert loss.requires_grad is False
|
||||
tot_loss = tot_loss + loss_info
|
||||
|
||||
if world_size > 1:
|
||||
tot_loss.reduce(loss.device)
|
||||
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
if loss_value < params.best_valid_loss:
|
||||
params.best_valid_epoch = params.cur_epoch
|
||||
params.best_valid_loss = loss_value
|
||||
|
||||
return tot_loss
|
||||
|
||||
|
||||
def train_one_epoch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
) -> None:
|
||||
"""Train the model for one epoch.
|
||||
|
||||
The training loss from the mean of all sentences is saved in
|
||||
`params.train_loss`. It runs the validation process every
|
||||
`params.valid_interval` batches.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The model for training.
|
||||
optimizer:
|
||||
The optimizer we are using.
|
||||
train_dl:
|
||||
Dataloader for the training dataset.
|
||||
valid_dl:
|
||||
Dataloader for the validation dataset.
|
||||
tb_writer:
|
||||
Writer to write log messages to tensorboard.
|
||||
world_size:
|
||||
Number of nodes in DDP training. If it is 1, DDP is disabled.
|
||||
"""
|
||||
model.train()
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
x, y, sentence_lengths = batch
|
||||
batch_size = x.size(0)
|
||||
|
||||
loss, loss_info = compute_loss(
|
||||
model=model,
|
||||
x=x,
|
||||
y=y,
|
||||
sentence_lengths=sentence_lengths,
|
||||
is_training=True,
|
||||
)
|
||||
|
||||
# summary stats
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||
optimizer.step()
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
# Note: "frames" here means "num_tokens"
|
||||
this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
|
||||
tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"])
|
||||
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, "
|
||||
f"batch {batch_idx}, loss[{loss_info}, ppl: {this_batch_ppl}] "
|
||||
f"tot_loss[{tot_loss}, ppl: {tot_ppl}], "
|
||||
f"batch size: {batch_size}"
|
||||
)
|
||||
|
||||
if tb_writer is not None:
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
|
||||
tb_writer.add_scalar(
|
||||
"train/current_ppl", this_batch_ppl, params.batch_idx_train
|
||||
)
|
||||
|
||||
tb_writer.add_scalar(
|
||||
"train/tot_ppl", tot_ppl, params.batch_idx_train
|
||||
)
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
logging.info("Computing validation loss")
|
||||
|
||||
valid_info = compute_validation_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
valid_dl=valid_dl,
|
||||
world_size=world_size,
|
||||
)
|
||||
model.train()
|
||||
|
||||
valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"])
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, validation: {valid_info}, "
|
||||
f"ppl: {valid_ppl}"
|
||||
)
|
||||
|
||||
if tb_writer is not None:
|
||||
valid_info.write_summary(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
|
||||
tb_writer.add_scalar(
|
||||
"train/valid_ppl", valid_ppl, params.batch_idx_train
|
||||
)
|
||||
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
params.train_loss = loss_value
|
||||
if params.train_loss < params.best_train_loss:
|
||||
params.best_train_epoch = params.cur_epoch
|
||||
params.best_train_loss = params.train_loss
|
||||
|
||||
|
||||
def run(rank, world_size, args):
|
||||
"""
|
||||
Args:
|
||||
rank:
|
||||
It is a value between 0 and `world_size-1`, which is
|
||||
passed automatically by `mp.spawn()` in :func:`main`.
|
||||
The node with rank 0 is responsible for saving checkpoint.
|
||||
world_size:
|
||||
Number of GPUs for DDP training.
|
||||
args:
|
||||
The return value of get_parser().parse_args()
|
||||
"""
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
fix_random_seed(42)
|
||||
if world_size > 1:
|
||||
setup_dist(rank, world_size, params.master_port)
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||
logging.info("Training started")
|
||||
logging.info(params)
|
||||
|
||||
if args.tensorboard and rank == 0:
|
||||
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||
else:
|
||||
tb_writer = None
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", rank)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
logging.info("About to create model")
|
||||
model = RnnLmModel(
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.embedding_dim,
|
||||
hidden_dim=params.hidden_dim,
|
||||
num_layers=params.num_layers,
|
||||
)
|
||||
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
|
||||
model.to(device)
|
||||
if world_size > 1:
|
||||
model = DDP(model, device_ids=[rank])
|
||||
|
||||
model.device = device
|
||||
|
||||
optimizer = optim.Adam(
|
||||
model.parameters(),
|
||||
lr=params.lr,
|
||||
weight_decay=params.weight_decay,
|
||||
)
|
||||
if checkpoints:
|
||||
logging.info("Load optimizer state_dict from checkpoint")
|
||||
optimizer.load_state_dict(checkpoints["optimizer"])
|
||||
|
||||
logging.info(f"Loading LM training data from {params.lm_data}")
|
||||
train_dl = get_dataloader(
|
||||
filename=params.lm_data,
|
||||
is_distributed=world_size > 1,
|
||||
params=params,
|
||||
)
|
||||
|
||||
logging.info(f"Loading LM validation data from {params.lm_data_valid}")
|
||||
valid_dl = get_dataloader(
|
||||
filename=params.lm_data_valid,
|
||||
is_distributed=world_size > 1,
|
||||
params=params,
|
||||
)
|
||||
|
||||
# Note: No learning rate scheduler is used here
|
||||
for epoch in range(params.start_epoch, params.num_epochs):
|
||||
if world_size > 1:
|
||||
train_dl.sampler.set_epoch(epoch)
|
||||
|
||||
params.cur_epoch = epoch
|
||||
|
||||
train_one_epoch(
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
train_dl=train_dl,
|
||||
valid_dl=valid_dl,
|
||||
tb_writer=tb_writer,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
save_checkpoint(
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
if world_size > 1:
|
||||
torch.distributed.barrier()
|
||||
cleanup_dist()
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
world_size = args.world_size
|
||||
assert world_size >= 1
|
||||
if world_size > 1:
|
||||
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
||||
else:
|
||||
run(rank=0, world_size=1, args=args)
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user