further simplification

This commit is contained in:
zr_jin 2024-10-22 17:22:37 +08:00
parent 68a4b6d090
commit c2056a5ab5
25 changed files with 0 additions and 4816 deletions

View File

@@ -1,254 +0,0 @@
# Copyright (c) 2022 Xiaomi Corporation (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import torch
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.rnn_lm.model import RnnLmModel
from icefall.transformer_lm.model import TransformerLM
from icefall.utils import AttributeDict, str2bool
class LmScorer(torch.nn.Module):
"""This is a wrapper for NN LMs
The language models supported include:
RNN,
Transformer
"""
def __init__(
self,
lm_type: str,
params: AttributeDict,
device,
lm_scale: float = 0.3,
):
super(LmScorer, self).__init__()
assert lm_type in ["rnn", "transformer"], f"{lm_type} is not supported"
self.lm_type = lm_type
self.lm = self.get_lm(lm_type, device, params)
self.lm_scale = lm_scale
self.params = params
@classmethod
def add_arguments(cls, parser):
# LM general arguments
parser.add_argument(
"--lm-vocab-size",
type=int,
default=500,
)
parser.add_argument(
"--lm-epoch",
type=int,
default=7,
help="""Which epoch to be used
""",
)
parser.add_argument(
"--lm-avg",
type=int,
default=1,
help="""Number of checkpoints to be averaged
""",
)
parser.add_argument("--lm-exp-dir", type=str, help="Path to LM experiments")
# Now RNNLM related arguments
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=3,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
type=str2bool,
default=True,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
# Now transformers
parser.add_argument(
"--transformer-lm-exp-dir", type=str, help="Directory of transformer LM exp"
)
parser.add_argument(
"--transformer-lm-dim-feedforward",
type=int,
default=2048,
help="Dimension of FFW module in transformer",
)
parser.add_argument(
"--transformer-lm-encoder-dim",
type=int,
default=768,
help="Encoder dimension of transformer",
)
parser.add_argument(
"--transformer-lm-embedding-dim",
type=int,
default=768,
help="Input embedding dimension of transformer",
)
parser.add_argument(
"--transformer-lm-nhead",
type=int,
default=8,
help="Number of attention heads in transformer",
)
parser.add_argument(
"--transformer-lm-num-layers",
type=int,
default=16,
help="Number of encoder layers in transformer",
)
parser.add_argument(
"--transformer-lm-tie-weights",
type=str2bool,
default=True,
help="If tie weights in transformer LM",
)
def get_lm(self, lm_type: str, device, params: AttributeDict) -> torch.nn.Module:
"""Return the neural network LM
Args:
lm_type (str): Type name of NN LM
"""
if lm_type == "rnn":
model = RnnLmModel(
vocab_size=params.lm_vocab_size,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
)
if params.lm_avg == 1:
load_checkpoint(
f"{params.lm_exp_dir}/epoch-{params.lm_epoch}.pt", model
)
model.to(device)
else:
start = params.lm_epoch - params.lm_avg + 1
filenames = []
for i in range(start, params.lm_epoch + 1):
                    if i >= 0:
filenames.append(f"{params.lm_exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif lm_type == "transformer":
model = TransformerLM(
vocab_size=params.lm_vocab_size,
d_model=params.transformer_lm_encoder_dim,
embedding_dim=params.transformer_lm_embedding_dim,
dim_feedforward=params.transformer_lm_dim_feedforward,
nhead=params.transformer_lm_nhead,
num_layers=params.transformer_lm_num_layers,
tie_weights=params.transformer_lm_tie_weights,
params=params,
)
if params.lm_avg == 1:
load_checkpoint(
f"{params.lm_exp_dir}/epoch-{params.lm_epoch}.pt", model
)
model.to(device)
else:
start = params.lm_epoch - params.lm_avg + 1
filenames = []
for i in range(start, params.lm_epoch + 1):
                    if i >= 0:
filenames.append(f"{params.lm_exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
raise NotImplementedError()
return model
def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None):
"""Score the input and return the prediction
This requires the lm to have the method `score_token`
Args:
x (torch.Tensor): Input tokens
x_lens (torch.Tensor): Length of the input tokens
state (optional): LM states
"""
return self.lm.score_token(x, x_lens, state)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
LmScorer.add_arguments(parser)
args = parser.parse_args()
params = AttributeDict()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
    Scorer = LmScorer(lm_type="rnn", params=params, device=device)  # lm_type is required; "rnn" matches the defaults above
Scorer.eval()
x = (
torch.tensor([[1, 4, 19, 256, 77], [1, 4, 19, 256, 77]])
.to(device)
.to(torch.int64)
)
x_lens = torch.tensor([5, 5]).to(device)
state = None
    score, state = Scorer.score_token(x, x_lens, state)
print(score.shape)
print(score[0])
print(score[1])
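
For context, here is a minimal sketch (not part of the deleted file above) of how a wrapper like this is typically combined with an ASR model's output during LM shallow fusion. It assumes `Scorer` and `device` from the `__main__` block above; `asr_log_probs` is a hypothetical placeholder for the ASR model's per-token log-probabilities.

# Hypothetical shallow-fusion step: score the last emitted token of each
# hypothesis and mix the LM log-probs into the ASR log-probs.
tokens = torch.tensor([[25], [7]], dtype=torch.int64, device=device)
token_lens = torch.tensor([1, 1], dtype=torch.int64, device=device)
lm_log_probs, lm_state = Scorer.score_token(tokens, token_lens, state=None)
asr_log_probs = torch.zeros_like(lm_log_probs)  # placeholder for the ASR output
combined = asr_log_probs + Scorer.lm_scale * lm_log_probs
# `lm_state` would be passed back into score_token() when scoring the next token.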

View File

@@ -1 +0,0 @@
icefall-librispeech-rnn-lm

View File

@@ -1,128 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation
"""
Usage:
./check-onnx-streaming.py \
--jit ./icefall-librispeech-rnn-lm/exp/cpu_jit.pt \
--onnx ./icefall-librispeech-rnn-lm/exp/with-state-epoch-99-avg-1.onnx
Note: You can download pre-trained models from
https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
"""
import argparse
import logging
from typing import Tuple
import onnxruntime as ort
import torch
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--jit",
required=True,
type=str,
help="Path to the torchscript model",
)
parser.add_argument(
"--onnx",
required=True,
type=str,
help="Path to the onnx model",
)
return parser
class OnnxModel:
def __init__(self, filename: str):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.model = ort.InferenceSession(
filename,
sess_options=session_opts,
)
meta_data = self.model.get_modelmeta().custom_metadata_map
self.sos_id = int(meta_data["sos_id"])
self.eos_id = int(meta_data["eos_id"])
self.vocab_size = int(meta_data["vocab_size"])
self.num_layers = int(meta_data["num_layers"])
self.hidden_size = int(meta_data["hidden_size"])
print(meta_data)
def __call__(
self, x: torch.Tensor, y: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
out = self.model.run(
[
self.model.get_outputs()[0].name,
self.model.get_outputs()[1].name,
self.model.get_outputs()[2].name,
],
{
self.model.get_inputs()[0].name: x.numpy(),
self.model.get_inputs()[1].name: y.numpy(),
self.model.get_inputs()[2].name: h0.numpy(),
self.model.get_inputs()[3].name: c0.numpy(),
},
)
return (
torch.from_numpy(out[0]),
torch.from_numpy(out[1]),
torch.from_numpy(out[2]),
)
@torch.no_grad()
def main():
args = get_parser().parse_args()
logging.info(vars(args))
torch_model = torch.jit.load(args.jit).cpu()
onnx_model = OnnxModel(args.onnx)
N = torch.arange(1, 5).tolist()
num_layers = onnx_model.num_layers
hidden_size = onnx_model.hidden_size
for n in N:
L = torch.randint(low=1, high=100, size=(1,)).item()
x = torch.randint(
low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
)
h0 = torch.rand(num_layers, n, hidden_size)
c0 = torch.rand(num_layers, n, hidden_size)
torch_log_prob, torch_h0, torch_c0 = torch_model.score_token_onnx(x, h0, c0)
onnx_log_prob, onnx_h0, onnx_c0 = onnx_model(x, h0, c0)
for torch_v, onnx_v in zip(
(torch_log_prob, torch_h0, torch_c0), (onnx_log_prob, onnx_h0, onnx_c0)
):
assert torch.allclose(torch_v, onnx_v, atol=1e-5), (
torch_v.shape,
onnx_v.shape,
(torch_v - onnx_v).abs().max(),
)
print(n, L, torch_v.sum(), onnx_v.sum())
if __name__ == "__main__":
torch.manual_seed(20230423)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@@ -1,119 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation
"""
Usage:
./check-onnx.py \
--jit ./icefall-librispeech-rnn-lm/exp/cpu_jit.pt \
--onnx ./icefall-librispeech-rnn-lm/exp/no-state-epoch-99-avg-1.onnx
Note: You can download pre-trained models from
https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
"""
import argparse
import logging
import onnxruntime as ort
import torch
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--jit",
required=True,
type=str,
help="Path to the torchscript model",
)
parser.add_argument(
"--onnx",
required=True,
type=str,
help="Path to the onnx model",
)
return parser
class OnnxModel:
def __init__(self, filename: str):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.model = ort.InferenceSession(
filename,
sess_options=session_opts,
)
meta_data = self.model.get_modelmeta().custom_metadata_map
self.sos_id = int(meta_data["sos_id"])
self.eos_id = int(meta_data["eos_id"])
self.vocab_size = int(meta_data["vocab_size"])
print(meta_data)
def __call__(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
out = self.model.run(
[
self.model.get_outputs()[0].name,
],
{
self.model.get_inputs()[0].name: x.numpy(),
self.model.get_inputs()[1].name: x_lens.numpy(),
},
)
return torch.from_numpy(out[0])
@torch.no_grad()
def main():
args = get_parser().parse_args()
logging.info(vars(args))
torch_model = torch.jit.load(args.jit).cpu()
onnx_model = OnnxModel(args.onnx)
N = torch.arange(1, 5).tolist()
for n in N:
L = torch.randint(low=1, high=100, size=(1,)).item()
x = torch.randint(
low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
)
x_lens = torch.full((n,), fill_value=L, dtype=torch.int64)
if n > 1:
x_lens[0] = L // 2 + 1
sos = torch.full((1,), fill_value=onnx_model.sos_id).expand(n, 1)
sos_x = torch.cat([sos, x], dim=1)
pad_col = torch.zeros((1,), dtype=x.dtype).expand(n, 1)
x_eos = torch.cat([x, pad_col], dim=1)
row_index = torch.arange(0, n, dtype=x.dtype)
x_eos[row_index, x_lens] = onnx_model.eos_id
torch_nll = torch_model(sos_x, x_eos, x_lens + 1).sum(dim=-1)
onnx_nll = onnx_model(x, x_lens)
# Note: For int8 models, the differences may be quite large,
# e.g., within 0.9
assert torch.allclose(torch_nll, onnx_nll), (
torch_nll,
onnx_nll,
)
print(n, L, torch_nll, onnx_nll)
if __name__ == "__main__":
torch.manual_seed(20230420)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@@ -1,275 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./rnn_lm/compute_perplexity.py \
--epoch 4 \
--avg 2 \
--lm-data ./data/lm_training_bpe_500/sorted_lm_data-test.pt
"""
import argparse
import logging
import math
from pathlib import Path
import torch
from dataset import get_dataloader
from model import RnnLmModel
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import AttributeDict, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=49,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=20,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="rnn_lm/exp",
help="The experiment dir",
)
parser.add_argument(
"--lm-data",
type=str,
help="Path to the LM test data for computing perplexity",
)
parser.add_argument(
"--vocab-size",
type=int,
default=500,
help="Vocabulary size of the model",
)
parser.add_argument(
"--embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--num-layers",
type=int,
default=3,
help="Number of RNN layers the model",
)
parser.add_argument(
"--tie-weights",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
parser.add_argument(
"--batch-size",
type=int,
default=50,
help="Number of RNN layers the model",
)
parser.add_argument(
"--max-sent-len",
type=int,
default=100,
help="Number of RNN layers the model",
)
parser.add_argument(
"--sos-id",
type=int,
default=1,
help="SOS ID",
)
parser.add_argument(
"--eos-id",
type=int,
default=1,
help="EOS ID",
)
parser.add_argument(
"--blank-id",
type=int,
default=0,
help="Blank ID",
)
return parser
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lm_data = Path(args.lm_data)
params = AttributeDict(vars(args))
if params.iter > 0:
setup_logger(
f"{params.exp_dir}/log-ppl/log-ppl-iter-{params.iter}-avg-{params.avg}"
)
else:
setup_logger(
f"{params.exp_dir}/log-ppl/log-ppl-epoch-{params.epoch}-avg-{params.avg}"
)
logging.info("Computing perplexity started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
logging.info("About to create model")
model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
hidden_dim=params.hidden_dim,
num_layers=params.num_layers,
tie_weights=params.tie_weights,
)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
num_param_requires_grad = sum(
[p.numel() for p in model.parameters() if p.requires_grad]
)
logging.info(f"Number of model parameters: {num_param}")
logging.info(
f"Number of model parameters (requires_grad): "
f"{num_param_requires_grad} "
f"({num_param_requires_grad/num_param_requires_grad*100}%)"
)
logging.info(f"Loading LM test data from {params.lm_data}")
test_dl = get_dataloader(
filename=params.lm_data,
is_distributed=False,
params=params,
)
tot_loss = 0.0
num_tokens = 0
num_sentences = 0
for batch_idx, batch in enumerate(test_dl):
x, y, sentence_lengths = batch
x = x.to(device)
y = y.to(device)
sentence_lengths = sentence_lengths.to(device)
nll = model(x, y, sentence_lengths)
loss = nll.sum().cpu().item()
tot_loss += loss
num_tokens += sentence_lengths.sum().cpu().item()
num_sentences += x.size(0)
ppl = math.exp(tot_loss / num_tokens)
logging.info(
f"total nll: {tot_loss}, num tokens: {num_tokens}, "
f"num sentences: {num_sentences}, ppl: {ppl:.3f}"
)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()
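
The perplexity reported by this script is simply exp(total negative log-likelihood / number of tokens). A tiny standalone example of the arithmetic, with made-up numbers:

import math

# Made-up totals: 35000.0 nats of summed negative log-likelihood over 10000 tokens.
tot_loss = 35000.0
num_tokens = 10000
ppl = math.exp(tot_loss / num_tokens)  # exp(3.5) ≈ 33.115
print(f"ppl: {ppl:.3f}")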

View File

@@ -1,214 +0,0 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
import k2
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from icefall.utils import AttributeDict, add_eos, add_sos
class LmDataset(torch.utils.data.Dataset):
def __init__(
self,
sentences: k2.RaggedTensor,
words: k2.RaggedTensor,
sentence_lengths: torch.Tensor,
max_sent_len: int,
batch_size: int,
):
"""
Args:
sentences:
A ragged tensor of dtype torch.int32 with 2 axes [sentence][word].
words:
A ragged tensor of dtype torch.int32 with 2 axes [word][token].
sentence_lengths:
A 1-D tensor of dtype torch.int32 containing number of tokens
of each sentence.
max_sent_len:
Maximum sentence length. It is used to change the batch size
dynamically. In general, we try to keep the product of
"max_sent_len in a batch" and "num_of_sent in a batch" being
a constant.
batch_size:
The expected batch size. It is changed dynamically according
to the "max_sent_len".
See `../local/prepare_lm_training_data.py` for how `sentences` and
`words` are generated. We assume that `sentences` are sorted by length.
See `../local/sort_lm_training_data.py`.
"""
super().__init__()
self.sentences = sentences
self.words = words
sentence_lengths = sentence_lengths.tolist()
assert batch_size > 0, batch_size
assert max_sent_len > 1, max_sent_len
batch_indexes = []
num_sentences = sentences.dim0
cur = 0
while cur < num_sentences:
sz = sentence_lengths[cur] // max_sent_len + 1
            # Suppose the current sentence has 3 * max_sent_len tokens; in the
            # worst case the subsequent sentences have just as many, so we
            # reduce the batch size to keep the number of tokens in this
            # batch from growing too large.
actual_batch_size = batch_size // sz + 1
actual_batch_size = min(actual_batch_size, batch_size)
end = cur + actual_batch_size
end = min(end, num_sentences)
this_batch_indexes = torch.arange(cur, end).tolist()
batch_indexes.append(this_batch_indexes)
cur = end
assert batch_indexes[-1][-1] == num_sentences - 1
self.batch_indexes = k2.RaggedTensor(batch_indexes)
def __len__(self) -> int:
"""Return number of batches in this dataset"""
return self.batch_indexes.dim0
def __getitem__(self, i: int) -> k2.RaggedTensor:
"""Get the i'th batch in this dataset
Return a ragged tensor with 2 axes [sentence][token].
"""
assert 0 <= i < len(self), i
# indexes is a 1-D tensor containing sentence indexes
indexes = self.batch_indexes[i]
# sentence_words is a ragged tensor with 2 axes
# [sentence][word]
sentence_words = self.sentences[indexes]
# in case indexes contains only 1 entry, the returned
# sentence_words is a 1-D tensor, we have to convert
# it to a ragged tensor
if isinstance(sentence_words, torch.Tensor):
sentence_words = k2.RaggedTensor(sentence_words.unsqueeze(0))
# sentence_word_tokens is a ragged tensor with 3 axes
# [sentence][word][token]
sentence_word_tokens = self.words.index(sentence_words)
assert sentence_word_tokens.num_axes == 3
sentence_tokens = sentence_word_tokens.remove_axis(1)
return sentence_tokens
class LmDatasetCollate:
def __init__(self, sos_id: int, eos_id: int, blank_id: int):
"""
Args:
sos_id:
Token ID of the SOS symbol.
eos_id:
Token ID of the EOS symbol.
blank_id:
Token ID of the blank symbol.
"""
self.sos_id = sos_id
self.eos_id = eos_id
self.blank_id = blank_id
def __call__(
self, batch: List[k2.RaggedTensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Return a tuple containing 3 tensors:
- x, a 2-D tensor of dtype torch.int32; each row contains tokens
for a sentence starting with `self.sos_id`. It is padded to
the max sentence length with `self.blank_id`.
- y, a 2-D tensor of dtype torch.int32; each row contains tokens
for a sentence ending with `self.eos_id` before padding.
Then it is padded to the max sentence length with
`self.blank_id`.
        - lengths, a 1-D tensor of dtype torch.int32, containing the number of
tokens of each sentence before padding.
"""
# The batching stuff has already been done in LmDataset
assert len(batch) == 1
sentence_tokens = batch[0]
row_splits = sentence_tokens.shape.row_splits(1)
sentence_token_lengths = row_splits[1:] - row_splits[:-1]
sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id)
sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id)
x = sentence_tokens_with_sos.pad(mode="constant", padding_value=self.blank_id)
y = sentence_tokens_with_eos.pad(mode="constant", padding_value=self.blank_id)
sentence_token_lengths += 1 # plus 1 since we added a SOS
return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths
def get_dataloader(
filename: str,
is_distributed: bool,
params: AttributeDict,
) -> torch.utils.data.DataLoader:
"""Get dataloader for LM training.
Args:
filename:
Path to the file containing LM data. The file is assumed to
be generated by `../local/sort_lm_training_data.py`.
is_distributed:
True if using DDP training. False otherwise.
params:
        See `get_params()` in `rnn_lm/train.py`.
Returns:
Return a dataloader containing the LM data.
"""
lm_data = torch.load(filename)
words = lm_data["words"]
sentences = lm_data["sentences"]
sentence_lengths = lm_data["sentence_lengths"]
dataset = LmDataset(
sentences=sentences,
words=words,
sentence_lengths=sentence_lengths,
max_sent_len=params.max_sent_len,
batch_size=params.batch_size,
)
if is_distributed:
sampler = DistributedSampler(dataset, shuffle=True, drop_last=True)
else:
sampler = None
collate_fn = LmDatasetCollate(
sos_id=params.sos_id,
eos_id=params.eos_id,
blank_id=params.blank_id,
)
dataloader = DataLoader(
dataset,
batch_size=1,
collate_fn=collate_fn,
sampler=sampler,
shuffle=sampler is None,
)
return dataloader
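
To make the dynamic batch-size rule in LmDataset.__init__ above concrete, here is a small standalone illustration with made-up sentence lengths; it reuses the same formula, so batches whose leading sentence is long contain fewer sentences.

# Sentences are assumed to be sorted by length, longest first (made-up lengths).
sentence_lengths = [650, 420, 180, 90, 60, 30, 20, 10]
max_sent_len = 200
batch_size = 4
cur = 0
while cur < len(sentence_lengths):
    sz = sentence_lengths[cur] // max_sent_len + 1
    actual_batch_size = min(batch_size // sz + 1, batch_size)
    end = min(cur + actual_batch_size, len(sentence_lengths))
    print(f"batch of {end - cur} sentences, longest has {sentence_lengths[cur]} tokens")
    cur = end
# Prints batches of 2, 4 and 2 sentences: the first batch is narrow because its
# sentences are long, so the token count per batch stays roughly constant.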

View File

@@ -1,395 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation
import argparse
import logging
from pathlib import Path
from typing import Dict
import onnx
import torch
from model import RnnLmModel
from onnxruntime.quantization import QuantType, quantize_dynamic
from train import get_params
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import AttributeDict, str2bool
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
# A wrapper for the RnnLm model to simplify the C++ calling code
# when exporting the model to ONNX.
#
# TODO(fangjun): The current wrapper works only for non-streaming ASR
# since we don't expose the LM state and it is used to score
# a complete sentence at once.
class RnnLmModelWrapper(torch.nn.Module):
def __init__(self, model: RnnLmModel, sos_id: int, eos_id: int):
super().__init__()
self.model = model
self.sos_id = sos_id
self.eos_id = eos_id
def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
A 2-D tensor of shape (N, L) with dtype torch.int64.
It does not contain SOS or EOS. We will add SOS and EOS inside
this function.
x_lens:
A 1-D tensor of shape (N,) with dtype torch.int64. It contains
number of valid tokens in ``x`` before padding.
Returns:
Return a 1-D tensor of shape (N,) containing negative loglikelihood.
Its dtype is torch.float32
"""
N = x.size(0)
sos_tensor = torch.full((1,), fill_value=self.sos_id, dtype=x.dtype).expand(
N, 1
)
sos_x = torch.cat([sos_tensor, x], dim=1)
pad_col = torch.zeros((1,), dtype=x.dtype).expand(N, 1)
x_eos = torch.cat([x, pad_col], dim=1)
row_index = torch.arange(0, N, dtype=x.dtype)
x_eos[row_index, x_lens] = self.eos_id
# use x_lens + 1 here since we prepended x with sos
return (
self.model(x=sos_x, y=x_eos, lengths=x_lens + 1)
.to(torch.float32)
.sum(dim=1)
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=29,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=5,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--vocab-size",
type=int,
default=500,
help="Vocabulary size of the model",
)
parser.add_argument(
"--embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--num-layers",
type=int,
default=3,
help="Number of RNN layers the model",
)
parser.add_argument(
"--tie-weights",
type=str2bool,
default=True,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="rnn_lm/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
return parser
def export_without_state(
model: RnnLmModel,
filename: str,
params: AttributeDict,
opset_version: int,
):
model_wrapper = RnnLmModelWrapper(
model,
sos_id=params.sos_id,
eos_id=params.eos_id,
)
N = 1
L = 20
x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
x_lens = torch.full((N,), fill_value=L, dtype=torch.int64)
# Note(fangjun): The following warnings can be ignored.
# We can use ./check-onnx.py to validate the exported model with batch_size > 1
"""
torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX
with a batch_size other than 1, with a variable length with LSTM can cause
an error when running the ONNX model with a different batch size. Make sure
to save the model with a batch size of 1, or define the initial states
(h0/c0) as inputs of the model. warnings.warn("Exporting a model to ONNX
with a batch_size other than 1, " +
"""
torch.onnx.export(
model_wrapper,
(x, x_lens),
filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["nll"],
dynamic_axes={
"x": {0: "N", 1: "L"},
"x_lens": {0: "N"},
"nll": {0: "N"},
},
)
meta_data = {
"model_type": "rnnlm",
"version": "1",
"model_author": "k2-fsa",
"comment": "rnnlm without state",
"sos_id": str(params.sos_id),
"eos_id": str(params.eos_id),
"vocab_size": str(params.vocab_size),
"url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm",
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=filename, meta_data=meta_data)
def export_with_state(
model: RnnLmModel,
filename: str,
params: AttributeDict,
opset_version: int,
):
N = 1
L = 20
num_layers = model.rnn.num_layers
hidden_size = model.rnn.hidden_size
embedding_dim = model.embedding_dim
x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
h0 = torch.zeros(num_layers, N, hidden_size)
c0 = torch.zeros(num_layers, N, hidden_size)
# Note(fangjun): The following warnings can be ignored.
    # We can use ./check-onnx-streaming.py to validate the exported model with batch_size > 1
"""
torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX
with a batch_size other than 1, with a variable length with LSTM can cause
an error when running the ONNX model with a different batch size. Make sure
to save the model with a batch size of 1, or define the initial states
(h0/c0) as inputs of the model. warnings.warn("Exporting a model to ONNX
with a batch_size other than 1, " +
"""
torch.onnx.export(
model,
(x, h0, c0),
filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "h0", "c0"],
output_names=["log_softmax", "next_h0", "next_c0"],
dynamic_axes={
"x": {0: "N", 1: "L"},
"h0": {1: "N"},
"c0": {1: "N"},
"log_softmax": {0: "N"},
"next_h0": {1: "N"},
"next_c0": {1: "N"},
},
)
meta_data = {
"model_type": "rnnlm",
"version": "1",
"model_author": "k2-fsa",
"comment": "rnnlm state",
"sos_id": str(params.sos_id),
"eos_id": str(params.eos_id),
"vocab_size": str(params.vocab_size),
"num_layers": str(num_layers),
"hidden_size": str(hidden_size),
"embedding_dim": str(embedding_dim),
"url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm",
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
logging.info(params)
device = torch.device("cpu")
logging.info(f"device: {device}")
model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
hidden_dim=params.hidden_dim,
num_layers=params.num_layers,
tie_weights=params.tie_weights,
)
model.to(device)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
model.to("cpu")
model.eval()
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
opset_version = 13
logging.info("Exporting model without state")
filename = params.exp_dir / f"no-state-{suffix}.onnx"
export_without_state(
model=model,
filename=filename,
params=params,
opset_version=opset_version,
)
filename_int8 = params.exp_dir / f"no-state-{suffix}.int8.onnx"
quantize_dynamic(
model_input=filename,
model_output=filename_int8,
weight_type=QuantType.QInt8,
)
# now for streaming export
saved_forward = model.__class__.forward
model.__class__.forward = model.__class__.score_token_onnx
streaming_filename = params.exp_dir / f"with-state-{suffix}.onnx"
export_with_state(
model=model,
filename=streaming_filename,
params=params,
opset_version=opset_version,
)
model.__class__.forward = saved_forward
streaming_filename_int8 = params.exp_dir / f"with-state-{suffix}.int8.onnx"
quantize_dynamic(
model_input=streaming_filename,
model_output=streaming_filename_int8,
weight_type=QuantType.QInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
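
As a small standalone illustration of the SOS/EOS handling in RnnLmModelWrapper.forward above (the token IDs are made up; sos_id = eos_id = 1 matches get_params() in rnn_lm/train.py):

import torch

sos_id = eos_id = 1
x = torch.tensor([[5, 6, 7, 0]], dtype=torch.int64)  # one sentence, padded to length 4
x_lens = torch.tensor([3], dtype=torch.int64)        # only 3 tokens are valid

N = x.size(0)
sos_tensor = torch.full((1,), fill_value=sos_id, dtype=x.dtype).expand(N, 1)
sos_x = torch.cat([sos_tensor, x], dim=1)            # [[1, 5, 6, 7, 0]]
pad_col = torch.zeros((1,), dtype=x.dtype).expand(N, 1)
x_eos = torch.cat([x, pad_col], dim=1)
x_eos[torch.arange(N), x_lens] = eos_id              # [[5, 6, 7, 1, 0]]
print(sos_x, x_eos, x_lens + 1)                      # lengths grow by 1 for the prepended SOS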

View File

@@ -1,26 +0,0 @@
#!/usr/bin/env bash
# We use the model from
# https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main
# as an example
export CUDA_VISIBLE_DEVICES=
if [ ! -f ./icefall-librispeech-rnn-lm/exp/pretrained.pt ]; then
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
pushd icefall-librispeech-rnn-lm/exp
git lfs pull --include "pretrained.pt"
ln -s pretrained.pt epoch-99.pt
popd
fi
python3 ./export-onnx.py \
--exp-dir ./icefall-librispeech-rnn-lm/exp \
--epoch 99 \
--avg 1 \
--vocab-size 500 \
--embedding-dim 2048 \
--hidden-dim 2048 \
--num-layers 3 \
--tie-weights 1

View File

@@ -1,205 +0,0 @@
#!/usr/bin/env python3
#
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
import argparse
import logging
from pathlib import Path
import torch
from model import RnnLmModel
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import AttributeDict, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=29,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=5,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--vocab-size",
type=int,
default=500,
help="Vocabulary size of the model",
)
parser.add_argument(
"--embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--num-layers",
type=int,
default=3,
help="Number of RNN layers the model",
)
parser.add_argument(
"--tie-weights",
type=str2bool,
default=True,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="rnn_lm/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=True,
help="""True to save a model after applying torch.jit.script.
""",
)
return parser
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = AttributeDict({})
params.update(vars(args))
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
hidden_dim=params.hidden_dim,
num_layers=params.num_layers,
tie_weights=params.tie_weights,
)
model.to(device)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model.__class__.score_token_onnx = torch.jit.export(
model.__class__.score_token_onnx
)
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
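
A hedged sketch of loading the cpu_jit.pt produced by this script and scoring one token step. It assumes the script has been run as in the accompanying export.sh, so the path below exists, and it uses the model dimensions passed there.

import torch

model = torch.jit.load("./icefall-librispeech-rnn-lm/exp/cpu_jit.pt").cpu()
model.eval()

num_layers, hidden_size = 3, 2048                       # values used in export.sh
x = torch.tensor([[25]], dtype=torch.int64)             # one token for one hypothesis
h0 = torch.zeros(num_layers, 1, hidden_size)
c0 = torch.zeros(num_layers, 1, hidden_size)
log_probs, h1, c1 = model.score_token_onnx(x, h0, c0)   # exported via torch.jit.export above
print(log_probs.shape)                                  # (1, vocab_size)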

View File

@@ -1,27 +0,0 @@
#!/usr/bin/env bash
# We use the model from
# https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main
# as an example
export CUDA_VISIBLE_DEVICES=
if [ ! -f ./icefall-librispeech-rnn-lm/exp/pretrained.pt ]; then
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
pushd icefall-librispeech-rnn-lm/exp
git lfs pull --include "pretrained.pt"
ln -s pretrained.pt epoch-99.pt
popd
fi
python3 ./export.py \
--exp-dir ./icefall-librispeech-rnn-lm/exp \
--epoch 99 \
--avg 1 \
--vocab-size 500 \
--embedding-dim 2048 \
--hidden-dim 2048 \
--num-layers 3 \
--tie-weights 1 \
--jit 1

View File

@@ -1,295 +0,0 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Tuple
import torch
import torch.nn.functional as F
from icefall.utils import add_eos, add_sos, make_pad_mask
class RnnLmModel(torch.nn.Module):
def __init__(
self,
vocab_size: int,
embedding_dim: int,
hidden_dim: int,
num_layers: int,
tie_weights: bool = False,
):
"""
Args:
vocab_size:
Vocabulary size of BPE model.
embedding_dim:
Input embedding dimension.
hidden_dim:
Hidden dimension of RNN layers.
num_layers:
Number of RNN layers.
tie_weights:
True to share the weights between the input embedding layer and the
last output linear layer. See https://arxiv.org/abs/1608.05859
and https://arxiv.org/abs/1611.01462
"""
super().__init__()
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.tie_weights = tie_weights
self.input_embedding = torch.nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
)
self.rnn = torch.nn.LSTM(
input_size=embedding_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
)
self.output_linear = torch.nn.Linear(
in_features=hidden_dim, out_features=vocab_size
)
self.vocab_size = vocab_size
if tie_weights:
logging.info("Tying weights")
assert embedding_dim == hidden_dim, (embedding_dim, hidden_dim)
self.output_linear.weight = self.input_embedding.weight
else:
logging.info("Not tying weights")
self.cache = {}
def streaming_forward(
self,
x: torch.Tensor,
y: torch.Tensor,
h0: torch.Tensor,
c0: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 2-D tensor of shape (N, L). We won't prepend it with SOS.
y:
A 2-D tensor of shape (N, L). We won't append it with EOS.
h0:
A 3-D tensor of shape (num_layers, N, hidden_size).
(If proj_size > 0, then it is (num_layers, N, proj_size))
c0:
A 3-D tensor of shape (num_layers, N, hidden_size).
Returns:
Return a tuple containing 3 tensors:
- negative loglike (nll), a 1-D tensor of shape (N,)
- next_h0, a 3-D tensor with the same shape as h0
- next_c0, a 3-D tensor with the same shape as c0
"""
assert x.ndim == y.ndim == 2, (x.ndim, y.ndim)
assert x.shape == y.shape, (x.shape, y.shape)
# embedding is of shape (N, L, embedding_dim)
embedding = self.input_embedding(x)
# Note: We use batch_first==True
rnn_out, (next_h0, next_c0) = self.rnn(embedding, (h0, c0))
logits = self.output_linear(rnn_out)
nll_loss = F.cross_entropy(
logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none"
)
batch_size = x.size(0)
nll_loss = nll_loss.reshape(batch_size, -1).sum(dim=1)
return nll_loss, next_h0, next_c0
def forward(
self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor
) -> torch.Tensor:
"""
Args:
x:
A 2-D tensor with shape (N, L). Each row
contains token IDs for a sentence and starts with the SOS token.
y:
A shifted version of `x` and with EOS appended.
lengths:
A 1-D tensor of shape (N,). It contains the sentence lengths
before padding.
Returns:
Return a 2-D tensor of shape (N, L) containing negative log-likelihood
loss values. Note: Loss values for padding positions are set to 0.
"""
assert x.ndim == y.ndim == 2, (x.ndim, y.ndim)
assert lengths.ndim == 1, lengths.ndim
assert x.shape == y.shape, (x.shape, y.shape)
batch_size = x.size(0)
assert lengths.size(0) == batch_size, (lengths.size(0), batch_size)
# embedding is of shape (N, L, embedding_dim)
embedding = self.input_embedding(x)
# Note: We use batch_first==True
rnn_out, _ = self.rnn(embedding)
logits = self.output_linear(rnn_out)
# Note: No need to use `log_softmax()` here
# since F.cross_entropy() expects unnormalized probabilities
# nll_loss is of shape (N*L,)
# nll -> negative log-likelihood
nll_loss = F.cross_entropy(
logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none"
)
# Set loss values for padding positions to 0
mask = make_pad_mask(lengths).reshape(-1)
nll_loss.masked_fill_(mask, 0)
nll_loss = nll_loss.reshape(batch_size, -1)
return nll_loss
def predict_batch(self, tokens, token_lens, sos_id, eos_id, blank_id):
device = next(self.parameters()).device
batch_size = len(token_lens)
sos_tokens = add_sos(tokens, sos_id)
tokens_eos = add_eos(tokens, eos_id)
sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
x_tokens = x_tokens.to(torch.int64).to(device)
y_tokens = y_tokens.to(torch.int64).to(device)
sentence_lengths = sentence_lengths.to(torch.int64).to(device)
embedding = self.input_embedding(x_tokens)
# Note: We use batch_first==True
rnn_out, states = self.rnn(embedding)
logits = self.output_linear(rnn_out)
mask = torch.zeros(logits.shape).bool().to(device)
for i in range(batch_size):
mask[i, token_lens[i], :] = True
logits = logits[mask].reshape(batch_size, -1)
return logits[:, :].log_softmax(-1), states
def clean_cache(self):
self.cache = {}
def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None):
"""Score a batch of tokens, i.e each sample in the batch should be a
single token. For example, x = torch.tensor([[5],[10],[20]])
Args:
x (torch.Tensor):
A batch of tokens
x_lens (torch.Tensor):
The length of tokens in the batch before padding
state (optional):
Either None or a tuple of two torch.Tensor. Each tensor has
the shape of (num_layers, bs, hidden_dim)
Returns:
            A tuple (log_probs, states): log_probs has shape (N, vocab_size)
            and contains log-probabilities for the next token; states is the
            (h, c) pair to pass to the next call.
"""
device = next(self.parameters()).device
batch_size = x.size(0)
if state:
h, c = state
else:
h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to(
device
)
c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to(
device
)
embedding = self.input_embedding(x)
rnn_out, states = self.rnn(embedding, (h, c))
logits = self.output_linear(rnn_out)
return logits[:, 0].log_softmax(-1), states
def score_token_onnx(
self,
x: torch.Tensor,
state_h: torch.Tensor,
state_c: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Score a batch of tokens, i.e each sample in the batch should be a
single token. For example, x = torch.tensor([[5],[10],[20]])
Args:
x (torch.Tensor):
A batch of tokens
state_h:
state h of RNN has the shape of (num_layers, bs, hidden_dim)
state_c:
state c of RNN has the shape of (num_layers, bs, hidden_dim)
Returns:
            A tuple (log_probs, next_h0, next_c0): log_probs has shape
            (N, vocab_size); next_h0 and next_c0 are the updated LSTM states.
"""
embedding = self.input_embedding(x)
rnn_out, (next_h0, next_c0) = self.rnn(embedding, (state_h, state_c))
logits = self.output_linear(rnn_out)
return logits[:, 0].log_softmax(-1), next_h0, next_c0
def forward_with_state(
self, tokens, token_lens, sos_id, eos_id, blank_id, state=None
):
batch_size = len(token_lens)
if state:
h, c = state
else:
h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size)
c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size)
device = next(self.parameters()).device
sos_tokens = add_sos(tokens, sos_id)
tokens_eos = add_eos(tokens, eos_id)
sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
x_tokens = x_tokens.to(torch.int64).to(device)
y_tokens = y_tokens.to(torch.int64).to(device)
sentence_lengths = sentence_lengths.to(torch.int64).to(device)
embedding = self.input_embedding(x_tokens)
# Note: We use batch_first==True
rnn_out, states = self.rnn(embedding, (h, c))
logits = self.output_linear(rnn_out)
return logits, states
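
A minimal sketch of the intended incremental use of score_token: score the current token of each hypothesis and feed the returned (h, c) state back in for the next token. It assumes the file above is importable as `model` and uses small made-up dimensions.

import torch
from model import RnnLmModel  # the class defined above

lm = RnnLmModel(vocab_size=10, embedding_dim=16, hidden_dim=16, num_layers=2)
lm.eval()

with torch.no_grad():
    tokens = torch.tensor([[3], [7]], dtype=torch.int64)       # last token of 2 hypotheses
    token_lens = torch.tensor([1, 1], dtype=torch.int64)
    log_probs, state = lm.score_token(tokens, token_lens)      # state starts from zeros
    next_tokens = torch.tensor([[5], [2]], dtype=torch.int64)
    log_probs, state = lm.score_token(next_tokens, token_lens, state)
    print(log_probs.shape)                                     # (2, vocab_size) == (2, 10)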

View File

@@ -1,71 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import k2
import torch
from rnn_lm.dataset import LmDataset, LmDatasetCollate
def main():
sentences = k2.RaggedTensor(
[[0, 1, 2], [1, 0, 1], [0, 1], [1, 3, 0, 2, 0], [3], [0, 2, 1]]
)
words = k2.RaggedTensor([[3, 6], [2, 8, 9, 3], [5], [5, 6, 7, 8, 9]])
num_sentences = sentences.dim0
sentence_lengths = [0] * num_sentences
for i in range(num_sentences):
word_ids = sentences[i]
# NOTE: If word_ids is a tensor with only 1 entry,
# token_ids is a torch.Tensor
token_ids = words[word_ids]
if isinstance(token_ids, k2.RaggedTensor):
token_ids = token_ids.values
# token_ids is a 1-D tensor containing the BPE tokens
# of the current sentence
sentence_lengths[i] = token_ids.numel()
sentence_lengths = torch.tensor(sentence_lengths, dtype=torch.int32)
indices = torch.argsort(sentence_lengths, descending=True)
sentences = sentences[indices.to(torch.int32)]
sentence_lengths = sentence_lengths[indices]
dataset = LmDataset(
sentences=sentences,
words=words,
sentence_lengths=sentence_lengths,
max_sent_len=3,
batch_size=4,
)
collate_fn = LmDatasetCollate(sos_id=1, eos_id=-1, blank_id=0)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=1, collate_fn=collate_fn
)
for i in dataloader:
print(i)
# I've checked the output manually; the output is as expected.
if __name__ == "__main__":
main()

View File

@@ -1,103 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import k2
import torch
import torch.multiprocessing as mp
from rnn_lm.dataset import LmDataset, LmDatasetCollate
from torch import distributed as dist
def generate_data():
sentences = k2.RaggedTensor(
[[0, 1, 2], [1, 0, 1], [0, 1], [1, 3, 0, 2, 0], [3], [0, 2, 1]]
)
words = k2.RaggedTensor([[3, 6], [2, 8, 9, 3], [5], [5, 6, 7, 8, 9]])
num_sentences = sentences.dim0
sentence_lengths = [0] * num_sentences
for i in range(num_sentences):
word_ids = sentences[i]
# NOTE: If word_ids is a tensor with only 1 entry,
# token_ids is a torch.Tensor
token_ids = words[word_ids]
if isinstance(token_ids, k2.RaggedTensor):
token_ids = token_ids.values
# token_ids is a 1-D tensor containing the BPE tokens
# of the current sentence
sentence_lengths[i] = token_ids.numel()
sentence_lengths = torch.tensor(sentence_lengths, dtype=torch.int32)
indices = torch.argsort(sentence_lengths, descending=True)
sentences = sentences[indices.to(torch.int32)]
sentence_lengths = sentence_lengths[indices]
return sentences, words, sentence_lengths
def run(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12352"
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
sentences, words, sentence_lengths = generate_data()
dataset = LmDataset(
sentences=sentences,
words=words,
sentence_lengths=sentence_lengths,
max_sent_len=3,
batch_size=4,
)
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, shuffle=True, drop_last=False
)
collate_fn = LmDatasetCollate(sos_id=1, eos_id=-1, blank_id=0)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=1,
collate_fn=collate_fn,
sampler=sampler,
shuffle=False,
)
for i in dataloader:
print(f"rank: {rank}", i)
dist.destroy_process_group()
def main():
world_size = 2
mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@@ -1,69 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from rnn_lm.model import RnnLmModel
def test_rnn_lm_model():
vocab_size = 4
model = RnnLmModel(
vocab_size=vocab_size, embedding_dim=10, hidden_dim=10, num_layers=2
)
x = torch.tensor(
[
[1, 3, 2, 2],
[1, 2, 2, 0],
[1, 2, 0, 0],
]
)
y = torch.tensor(
[
[3, 2, 2, 1],
[2, 2, 1, 0],
[2, 1, 0, 0],
]
)
lengths = torch.tensor([4, 3, 2])
nll_loss = model(x, y, lengths)
print(nll_loss)
"""
tensor([[1.1180, 1.3059, 1.2426, 1.7773],
[1.4231, 1.2783, 1.7321, 0.0000],
[1.4231, 1.6752, 0.0000, 0.0000]], grad_fn=<ViewBackward>)
"""
def test_rnn_lm_model_tie_weights():
model = RnnLmModel(
vocab_size=10,
embedding_dim=10,
hidden_dim=10,
num_layers=2,
tie_weights=True,
)
assert model.input_embedding.weight is model.output_linear.weight
def main():
test_rnn_lm_model()
test_rnn_lm_model_tie_weights()
if __name__ == "__main__":
torch.manual_seed(20211122)
main()

View File

@@ -1,689 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./rnn_lm/train.py \
--start-epoch 0 \
--world-size 2 \
--num-epochs 1 \
--use-fp16 0 \
--tie-weights 0 \
--embedding-dim 800 \
--hidden-dim 200 \
--num-layers 2 \
--batch-size 400
"""
import argparse
import logging
import math
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from dataset import get_dataloader
from lhotse.utils import fix_random_seed
from model import RnnLmModel
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=30,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
exp_dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--start-batch",
type=int,
default=0,
help="""If positive, --start-epoch is ignored and
it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="rnn_lm/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, logs, etc, are saved
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=True,
help="Whether to use half precision training.",
)
parser.add_argument(
"--batch-size",
type=int,
default=400,
)
parser.add_argument(
"--lm-data",
type=str,
default="data/lm_training_bpe_500/sorted_lm_data.pt",
help="LM training data",
)
parser.add_argument(
"--lm-data-valid",
type=str,
default="data/lm_training_bpe_500/sorted_lm_data-valid.pt",
help="LM validation data",
)
parser.add_argument(
"--vocab-size",
type=int,
default=500,
help="Vocabulary size of the model",
)
parser.add_argument(
"--embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--num-layers",
type=int,
default=3,
help="Number of RNN layers the model",
)
parser.add_argument(
"--tie-weights",
type=str2bool,
default=True,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--lr",
type=float,
default=1e-3,
)
parser.add_argument(
"--max-sent-len",
type=int,
default=200,
help="""Maximum number of tokens in a sentence. This is used
to adjust batch-size dynamically""",
)
parser.add_argument(
"--save-every-n",
type=int,
default=2000,
help="""Save checkpoint after processing this number of batches"
periodically. We save checkpoint to exp-dir/ whenever
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
""",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters."""
params = AttributeDict(
{
"max_sent_len": 200,
"sos_id": 1,
"eos_id": 1,
"blank_id": 0,
"weight_decay": 1e-6,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 100,
"reset_interval": 2000,
"valid_interval": 200,
"env_info": get_env_info(),
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_batch is positive, it will load the checkpoint from
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_batch > 0:
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
elif params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
logging.info(f"Loading checkpoint: {filename}")
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
if params.start_batch > 0:
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
model: nn.Module,
x: torch.Tensor,
y: torch.Tensor,
sentence_lengths: torch.Tensor,
is_training: bool,
) -> Tuple[torch.Tensor, MetricsTracker]:
"""Compute the negative log-likelihood loss given a model and its input.
Args:
model:
The NN model, e.g., RnnLmModel.
x:
A 2-D tensor. Each row contains BPE token IDs for a sentence. Also,
each row starts with SOS ID.
y:
A 2-D tensor. Each row is a shifted version of the corresponding row
in `x` but ends with an EOS ID (before padding).
sentence_lengths:
A 1-D tensor containing number of tokens of each sentence
before padding.
is_training:
True for training. False for validation.
"""
with torch.set_grad_enabled(is_training):
device = model.device
x = x.to(device)
y = y.to(device)
sentence_lengths = sentence_lengths.to(device)
nll = model(x, y, sentence_lengths)
loss = nll.sum()
num_tokens = sentence_lengths.sum().item()
loss_info = MetricsTracker()
# Note: Due to how MetricsTracker() is designed,
# we use "frames" instead of "num_tokens" as a key here
loss_info["frames"] = num_tokens
loss_info["loss"] = loss.detach().item()
return loss, loss_info
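# A hedged illustration (added for clarity; not used by the training loop) of
# the (x, y, sentence_lengths) layout that compute_loss expects, assuming
# sos_id = eos_id = 1 and blank_id = 0 as configured in get_params() above.
def _toy_lm_batch() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # a single sentence whose BPE tokens are [5, 7, 9]
    x = torch.tensor([[1, 5, 7, 9]])  # row starts with SOS
    y = torch.tensor([[5, 7, 9, 1]])  # shifted copy of x, ending with EOS
    sentence_lengths = torch.tensor([4])  # tokens per sentence before padding
    return x, y, sentence_lengths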
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
x, y, sentence_lengths = batch
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
model=model,
x=x,
y=y,
sentence_lengths=sentence_lengths,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all sentences is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1
x, y, sentence_lengths = batch
batch_size = x.size(0)
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
model=model,
x=x,
y=y,
sentence_lengths=sentence_lengths,
is_training=True,
)
# summary stats: tot_loss keeps an exponentially decayed running sum of
# recent batch statistics, decayed by (1 - 1/reset_interval) each batch.
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if (
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
model=model,
params=params,
optimizer=optimizer,
rank=rank,
)
del params.cur_batch_idx
if batch_idx % params.log_interval == 0:
# Note: "frames" here means "num_tokens"
this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"])
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}, ppl: {this_batch_ppl}] "
f"tot_loss[{tot_loss}, ppl: {tot_ppl}], "
f"batch size: {batch_size}"
)
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tb_writer.add_scalar(
"train/current_ppl", this_batch_ppl, params.batch_idx_train
)
tb_writer.add_scalar("train/tot_ppl", tot_ppl, params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"])
logging.info(
f"Epoch {params.cur_epoch}, validation: {valid_info}, "
f"ppl: {valid_ppl}"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
tb_writer.add_scalar(
"train/valid_ppl", valid_ppl, params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
is_distributed = world_size > 1
fix_random_seed(params.seed)
if is_distributed:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
logging.info("About to create model")
model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
hidden_dim=params.hidden_dim,
num_layers=params.num_layers,
tie_weights=params.tie_weights,
)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if is_distributed:
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = optim.Adam(
model.parameters(),
lr=params.lr,
weight_decay=params.weight_decay,
)
if checkpoints:
logging.info("Load optimizer state_dict from checkpoint")
optimizer.load_state_dict(checkpoints["optimizer"])
logging.info(f"Loading LM training data from {params.lm_data}")
train_dl = get_dataloader(
filename=params.lm_data,
is_distributed=is_distributed,
params=params,
)
logging.info(f"Loading LM validation data from {params.lm_data_valid}")
valid_dl = get_dataloader(
filename=params.lm_data_valid,
is_distributed=is_distributed,
params=params,
)
# Note: No learning rate scheduler is used here
for epoch in range(params.start_epoch, params.num_epochs):
if is_distributed:
train_dl.sampler.set_epoch(epoch)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if is_distributed:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -1,510 +0,0 @@
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import List, Optional, Tuple
import torch
from torch import Tensor, nn
from icefall.transformer_lm.scaling import (
ActivationBalancer,
BasicNorm,
DoubleSwish,
ScaledConv1d,
ScaledConv2d,
ScaledLinear,
)
from icefall.utils import is_jit_tracing
class RelPositionMultiheadAttention(nn.Module):
r"""Multi-Head Attention layer with relative position encoding
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Args:
embed_dim: total dimension of the model.
num_heads: parallel attention heads.
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
Examples::
>>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = rel_pos_multihead_attn(query, key, value, pos_emb)
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
self.out_proj = ScaledLinear(
embed_dim, embed_dim, bias=True, initial_scale=0.25
)
# linear transformation for positional encoding.
self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
# these two learnable biases are used in matrix c and matrix d
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
self._reset_parameters()
def _pos_bias_u(self):
return self.pos_bias_u * self.pos_bias_u_scale.exp()
def _pos_bias_v(self):
return self.pos_bias_v * self.pos_bias_v_scale.exp()
def _reset_parameters(self) -> None:
nn.init.normal_(self.pos_bias_u, std=0.01)
nn.init.normal_(self.pos_bias_v, std=0.01)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
pos_emb: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = False,
attn_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
pos_emb: Positional embedding tensor
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. When given a binary mask and a value is True,
the corresponding value on the attention layer will be ignored. When given
a byte mask and a value is non-zero, the corresponding value on the attention
layer will be ignored
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape:
- Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the position
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position i is allowed to attend to the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
return self.multi_head_attention_forward(
query,
key,
value,
pos_emb,
self.embed_dim,
self.num_heads,
self.in_proj.get_weight(),
self.in_proj.get_bias(),
self.dropout,
self.out_proj.get_weight(),
self.out_proj.get_bias(),
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
left_context=left_context,
)
def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor:
"""Compute relative positional encoding.
Args:
x: Input tensor (batch, head, time1, 2*time1-1+left_context).
time1 means the length of query vector.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Returns:
Tensor: tensor of shape (batch, head, time1, time2)
(note: time2 has the same value as time1, but it is for
the key, while time1 is for the query).
"""
(batch_size, num_heads, time1, n) = x.shape
time2 = time1 + left_context
if not is_jit_tracing():
assert (
n == left_context + 2 * time1 - 1
), f"{n} == {left_context} + 2 * {time1} - 1"
if is_jit_tracing():
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(time2)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
x = x.reshape(-1, n)
x = torch.gather(x, dim=1, index=indexes)
x = x.reshape(batch_size, num_heads, time1, time2)
return x
else:
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time2),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
def multi_head_attention_forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
pos_emb: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Tensor,
in_proj_bias: Tensor,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Tensor,
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = False,
attn_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
pos_emb: Positional embedding tensor
embed_dim_to_check: total dimension of the model.
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
dropout_p: probability of an element to be zeroed.
out_proj_weight, out_proj_bias: the output projection weight and bias.
training: apply dropout if is ``True``.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. This is a binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape:
Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
length, N is the batch size, E is the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position i is allowed to attend to the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
tgt_len, bsz, embed_dim = query.size()
if not is_jit_tracing():
assert embed_dim == embed_dim_to_check
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
head_dim = embed_dim // num_heads
if not is_jit_tracing():
assert (
head_dim * num_heads == embed_dim
), "embed_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
if torch.equal(query, key) and torch.equal(key, value):
# self-attention
q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(
3, dim=-1
)
elif torch.equal(key, value):
# encoder-decoder attention
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
else:
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = embed_dim * 2
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
k = nn.functional.linear(key, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim * 2
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
v = nn.functional.linear(value, _w, _b)
if attn_mask is not None:
assert (
attn_mask.dtype == torch.float32
or attn_mask.dtype == torch.float64
or attn_mask.dtype == torch.float16
or attn_mask.dtype == torch.uint8
or attn_mask.dtype == torch.bool
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
attn_mask.dtype
)
if attn_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for attn_mask is deprecated. Use bool tensor instead."
)
attn_mask = attn_mask.to(torch.bool)
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError("The size of the 2D attn_mask is not correct.")
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * num_heads,
query.size(0),
key.size(0),
]:
raise RuntimeError("The size of the 3D attn_mask is not correct.")
else:
raise RuntimeError(
"attn_mask's dimension {} is not supported".format(attn_mask.dim())
)
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
warnings.warn(
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
)
key_padding_mask = key_padding_mask.to(torch.bool)
q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
k = k.contiguous().view(-1, bsz, num_heads, head_dim)
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
src_len = k.size(0)
if key_padding_mask is not None and not is_jit_tracing():
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
key_padding_mask.size(0), bsz
)
assert key_padding_mask.size(1) == src_len, "{} == {}".format(
key_padding_mask.size(1), src_len
)
q = q.transpose(0, 1) # (batch, time1, head, d_k)
pos_emb_bsz = pos_emb.size(0)
if not is_jit_tracing():
assert pos_emb_bsz in (1, bsz) # actually it is 1
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
# (batch, 2*time1-1, head, d_k) --> (batch, head, d_k, 2*time1-1)
p = p.permute(0, 2, 3, 1)
q_with_bias_u = (q + self._pos_bias_u()).transpose(
1, 2
) # (batch, head, time1, d_k)
q_with_bias_v = (q + self._pos_bias_v()).transpose(
1, 2
) # (batch, head, time1, d_k)
# compute attention score
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2)
# compute matrix b and matrix d
matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd, left_context)
attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)
if not is_jit_tracing():
assert list(attn_output_weights.size()) == [
bsz * num_heads,
tgt_len,
src_len,
]
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
else:
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
)
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
# If we are using dynamic_chunk_training and setting a limited
# num_left_chunks, the attention may only see the padding values which
# will also be masked out by `key_padding_mask`. In this circumstance,
# the whole column of `attn_output_weights` will be `-inf`
# (i.e. be `nan` after softmax), so, we fill `0.0` at the masking
# positions to avoid invalid loss value below.
if (
attn_mask is not None
and attn_mask.dtype == torch.bool
and key_padding_mask is not None
):
if attn_mask.size(0) != 1:
attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2)
else:
# attn_mask.shape == (1, tgt_len, src_len)
combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
1
).unsqueeze(2)
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
)
attn_output_weights = nn.functional.dropout(
attn_output_weights, p=dropout_p, training=training
)
attn_output = torch.bmm(attn_output_weights, v)
if not is_jit_tracing():
assert list(attn_output.size()) == [
bsz * num_heads,
tgt_len,
head_dim,
]
attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
)
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
return attn_output, attn_output_weights.sum(dim=1) / num_heads
else:
return attn_output, None
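# A minimal sketch (added for illustration; not used by this module) of the
# index arithmetic that rel_shift performs in its jit-tracing branch, shown
# for a single attention head with time1 = 3 and left_context = 0.
def _rel_shift_toy_example() -> Tensor:
    time1, left_context = 3, 0
    n = 2 * time1 - 1 + left_context  # 5 relative offsets
    x = torch.arange(n).repeat(time1, 1)  # (time1, n); each row is [0, 1, 2, 3, 4]
    rows = torch.arange(time1 - 1, -1, -1).unsqueeze(-1)  # [[2], [1], [0]]
    cols = torch.arange(time1 + left_context)  # [0, 1, 2]
    shifted = torch.gather(x, dim=1, index=rows + cols)
    # shifted == [[2, 3, 4], [1, 2, 3], [0, 1, 2]]: query position i reads a
    # time2-long window of the relative-offset axis starting at time1 - 1 - i.
    return shifted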

View File

@ -1,195 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
from pathlib import Path
import torch
from dataset import get_dataloader
from train import get_params
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.transformer_lm.model import TransformerLM
from icefall.utils import AttributeDict, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=7,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transformer_lm/exp_full_libri_16layer_maxlen200_8gpu",
)
parser.add_argument(
"--lm-data",
type=str,
help="Path to the LM test data for computing perplexity",
default="transformer_lm/libri_lm_training_bpe500/sorted_lm_data-test.pt",
)
parser.add_argument(
"--vocab-size",
type=int,
default=500,
help="Vocabulary size of the model",
)
parser.add_argument(
"--num-layers",
type=int,
default=16,
help="Number of RNN layers the model",
)
parser.add_argument(
"--tie-weights",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
parser.add_argument(
"--batch-size",
type=int,
default=50,
help="Number of RNN layers the model",
)
parser.add_argument(
"--max-sent-len",
type=int,
default=100,
help="Number of RNN layers the model",
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lm_data = Path(args.lm_data)
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log-ppl/")
logging.info("Computing perplexity started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
logging.info("About to create model")
model = TransformerLM(
vocab_size=params.vocab_size,
d_model=params.encoder_dim,
embedding_dim=params.embedding_dim,
dim_feedforward=params.dim_feedforward,
nhead=params.nhead,
num_layers=params.num_layers,
tie_weights=params.tie_weights,
params=params,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
num_param_requires_grad = sum(
[p.numel() for p in model.parameters() if p.requires_grad]
)
logging.info(f"Number of model parameters: {num_param}")
logging.info(
f"Number of model parameters (requires_grad): "
f"{num_param_requires_grad} "
f"({num_param_requires_grad/num_param_requires_grad*100}%)"
)
logging.info(f"Loading LM test data from {params.lm_data}")
test_dl = get_dataloader(
filename=params.lm_data,
is_distributed=False,
params=params,
)
tot_loss = 0.0
num_tokens = 0
num_sentences = 0
for batch_idx, batch in enumerate(test_dl):
x, y, sentence_lengths = batch
x = x.to(device)
y = y.to(device)
sentence_lengths = sentence_lengths.to(device)
nll = model(x, y, sentence_lengths)
loss = nll.sum().cpu().item()
tot_loss += loss
num_tokens += sentence_lengths.sum().cpu().item()
num_sentences += x.size(0)
ppl = math.exp(tot_loss / num_tokens)
logging.info(
f"total nll: {tot_loss}, num tokens: {num_tokens}, "
f"num sentences: {num_sentences}, ppl: {ppl:.3f}"
)
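# Worked example of the formula above (illustrative numbers, not from a
# real run): with a summed NLL of 40000.0 nats over 10000 test tokens,
# ppl = exp(40000.0 / 10000) = exp(4.0) ≈ 54.6.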
if __name__ == "__main__":
main()

View File

@ -1 +0,0 @@
../rnn_lm/dataset.py

View File

@ -1,329 +0,0 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from icefall.transformer_lm.attention import RelPositionMultiheadAttention
from icefall.transformer_lm.scaling import (
ActivationBalancer,
BasicNorm,
DoubleSwish,
ScaledConv1d,
ScaledConv2d,
ScaledLinear,
)
from icefall.utils import is_jit_tracing, make_pad_mask
class Transformer(torch.nn.Module):
"""_summary_
Args:
input_dim (int): Input feature dimension
d_mode (int): The dimension of the transformer
dim_feedforward (int ): The dimension of the ffw module
nhead (int): The number of attention heads
dropout_rate (float): dropout rate
att_dropout (float): dropout rate in attention module
"""
def __init__(
self,
input_dim: int,
d_model: int,
dim_feedforward: int,
nhead: int = 4,
num_layers: int = 6,
dropout_rate: float = 0.1,
att_dropout: float = 0.0,
):
super().__init__()
self.encoder_layers = num_layers
self.d_model = d_model
self.embed = ScaledLinear(input_dim, d_model)
self.norm_before = BasicNorm(d_model, learn_eps=False)
self.encoder_pos = RelPositionalEncoding(d_model, dropout_rate)
encoder_layer = TransformerEncoderLayer(
d_model=d_model,
dim_feedforward=dim_feedforward,
nhead=nhead,
dropout_rate=dropout_rate,
)
self.encoder = TransformerEncoder(encoder_layer, num_layers)
def _create_attention_mask(self, x_lens: torch.Tensor):
# create a 2D attention mask to mask out
# the upper right half of the attention matrix
max_len = max(x_lens)
ones = torch.ones(max_len, max_len, device=x_lens.device, dtype=torch.bool)
return torch.triu(ones, diagonal=1)
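# Illustration (added for clarity): for x_lens = [3] the mask returned above is
#   [[False,  True,  True],
#    [False, False,  True],
#    [False, False, False]]
# so position i may only attend to positions j <= i (causal LM attention).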
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Transformer forward
Args:
x (torch.Tensor): Input tensor (B,T,input_dim)
x_lens (torch.Tensor): The length of input tensors before padding (B,)
Returns:
Return a tuple of 2 tensors:
- x: output feature of the transformer (B,T,d_model)
- x_lens: output feature lens of the transformer
"""
attention_mask = self._create_attention_mask(x_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = self.norm_before(self.embed(x))
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2)
x = self.encoder(
x,
pos_emb,
mask=attention_mask,  # pass the attention mask
src_key_padding_mask=src_key_padding_mask,
) # (T, N, C)
x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
return x, x_lens
class TransformerEncoder(torch.nn.Module):
def __init__(self, encoder_layer: torch.nn.Module, num_layers: int) -> None:
"""TransformerEncoder is a stack of N encoder layers
Args:
encoder_layer (torch.nn.Module): an instance of the TransformerEncoderLayer()
num_layers (int): Number of layers to be stacked
"""
super().__init__()
self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
)
self.num_layers = num_layers
def forward(
self,
src: torch.Tensor,
pos_emb: torch.Tensor,
src_key_padding_mask: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""_summary_
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Returns:
output: transformer encoded features
"""
output = src
for layer_index, mod in enumerate(self.layers):
output = mod(
output,
pos_emb,
src_key_padding_mask=src_key_padding_mask,
src_mask=mask,
)
return output
class TransformerEncoderLayer(torch.nn.Module):
def __init__(
self,
d_model: int,
dim_feedforward: int,
nhead: int,
dropout_rate: float,
):
"""TransformerEncoderLayer is made up of self-attn and feedforward module
Args:
d_model (int): The model size
dim_feedforward (int): Dimension of ffw module
nhead (int): Number of heads
dropout_rate (float): Dropout rate
"""
super().__init__()
self.d_model = d_model
self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
self.feed_forward = nn.Sequential(
ScaledLinear(d_model, dim_feedforward),
ActivationBalancer(channel_dim=-1),
DoubleSwish(),
nn.Dropout(dropout_rate),
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
)
self.norm_final = BasicNorm(d_model)
self.balancer = ActivationBalancer(
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
)
self.dropout = nn.Dropout(dropout_rate)
def forward(
self,
src: torch.Tensor,
pos_emb: torch.Tensor,
src_key_padding_mask: Optional[torch.Tensor] = None,
src_mask: Optional[torch.Tensor] = None,
cache=None,
):
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
src_key_padding_mask: the mask for the src keys per batch (optional).
src_mask: the mask for the src sequence (optional).
"""
src_orig = src
src_att = self.self_attn(
src,
src,
src,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = src + self.dropout(src_att)
# feed forward module
src = src + self.dropout(self.feed_forward(src))
src = self.norm_final(self.balancer(src))
return src
class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module.
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
Args:
d_model: Embedding dimension.
dropout_rate: Dropout rate.
max_len: Maximum input length.
"""
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
if is_jit_tracing():
# 10k frames correspond to ~100k ms, i.e., 100 seconds. It assumes
# that the maximum input won't have more than 10k frames.
#
# TODO(fangjun): Use torch.jit.script() for this module
max_len = 10000
self.d_model = d_model
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None:
"""Reset the positional encodings."""
x_size_1 = x.size(1) + left_context
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x_size_1 * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` is the position of the query vector and `j` the
# position of the key vector. We use positive relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x_size_1, self.d_model)
pe_negative = torch.zeros(x_size_1, self.d_model)
position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
# Reverse the order of positive indices and concat both positive and
# negative indices. This is used to support the shifting trick
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(
self,
x: torch.Tensor,
left_context: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
"""
self.extend_pe(x, left_context)
x_size_1 = x.size(1) + left_context
pos_emb = self.pe[
:,
self.pe.size(1) // 2
- x_size_1
+ 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1),
]
return self.dropout(x), self.dropout(pos_emb)
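# Shape sketch for RelPositionalEncoding (illustrative; batch size 1 assumed):
# for an input x of shape (1, T, d_model), forward() returns
#   - dropout(x):       (1, T, d_model)
#   - dropout(pos_emb): (1, 2*T - 1, d_model)
# i.e. one positional embedding per relative offset in [-(T-1), T-1].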

View File

@ -1,186 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) 2022 Xiaomi Corporation (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
import argparse
import logging
from pathlib import Path
import torch
from model import TransformerLM
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, load_averaged_model, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=11,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=5,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--vocab-size",
type=int,
default=500,
help="Vocabulary size of the model",
)
parser.add_argument(
"--embedding-dim",
type=int,
default=768,
help="Embedding dim of the model",
)
parser.add_argument(
"--encoder-dim",
type=int,
default=768,
help="Encoder dim of the model",
)
parser.add_argument(
"--dim_feedforward",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--nhead",
type=int,
default=8,
help="Number of attention heads",
)
parser.add_argument(
"--num-layers",
type=int,
default=16,
help="Number of Transformer layers",
)
parser.add_argument(
"--tie-weights",
type=str2bool,
default=True,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="rnn_lm/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=True,
help="""True to save a model after applying torch.jit.script.
""",
)
return parser
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = AttributeDict({})
params.update(vars(args))
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("About to create model")
model = TransformerLM(
vocab_size=params.vocab_size,
d_model=params.encoder_dim,
embedding_dim=params.embedding_dim,
dim_feedforward=params.dim_feedforward,
nhead=params.nhead,
num_layers=params.num_layers,
tie_weights=params.tie_weights,
params=params,
)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
model = load_averaged_model(
params.exp_dir, model, params.epoch, params.avg, device
)
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
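# Usage note (assumption, not part of this script): the scripted model saved
# as cpu_jit.pt above can later be loaded without this class definition, e.g.
#   model = torch.jit.load("path/to/cpu_jit.pt", map_location="cpu")
#   model.eval()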

View File

@ -1,114 +0,0 @@
# Copyright (c) 2022 Xiaomi Corporation (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from icefall.transformer_lm.encoder import Transformer
from icefall.utils import AttributeDict, add_eos, add_sos, make_pad_mask
class TransformerLM(torch.nn.Module):
def __init__(
self,
vocab_size: int,
embedding_dim: int,
d_model: int,
dim_feedforward: int,
nhead: int = 8,
num_layers: int = 16,
tie_weights: bool = True,
dropout: float = 0.1,
emb_dropout_rate: float = 0.0,
params: AttributeDict = None,
):
super().__init__()
self.vocab_size = vocab_size
self.params = params
self.input_embedding = torch.nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
)
self.encoder = Transformer(
input_dim=embedding_dim,
d_model=d_model,
dim_feedforward=dim_feedforward,
nhead=nhead,
num_layers=num_layers,
dropout_rate=dropout,
)
self.output_linear = torch.nn.Linear(
in_features=d_model, out_features=vocab_size
)
if tie_weights:
logging.info("Tying weights")
assert d_model == embedding_dim, (d_model, embedding_dim)
self.output_linear.weight = self.input_embedding.weight
else:
logging.info("Not tying weights")
def forward(
self,
x: torch.Tensor,
y: torch.Tensor,
x_lens: torch.Tensor,
return_logits: bool = False,
):
"""Forward transformer language model
Args:
x (torch.Tensor): Input tokens (B,L)
y (torch.Tensor): Output tokens (with EOS appended) (B,L)
x_lens (torch.Tensor): Length of input tokens before padding (B,)
return_logits (bool, optional): Return logits instead of NLL
"""
x = self.input_embedding(x)
x, x_lens = self.encoder(x, x_lens)
logits = self.output_linear(x)
if return_logits:
return logits
nll_loss = F.cross_entropy(
logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none"
)
mask = make_pad_mask(x_lens).reshape(-1)
nll_loss.masked_fill_(mask, 0)
return nll_loss
def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None):
bs = x.size(0)
state = None
logits = self.forward(x, x, x_lens, return_logits=True)
index = torch.arange(bs)
last_logits = logits[index, x_lens - 1, :]
return last_logits.log_softmax(-1), state
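# Minimal usage sketch (hypothetical values; not part of the original module):
#   lm = TransformerLM(vocab_size=500, embedding_dim=768, d_model=768,
#                      dim_feedforward=2048, nhead=8, num_layers=16)
#   tokens = torch.tensor([[1, 23, 87, 5]])  # (B=1, L=4), starting with SOS
#   lens = torch.tensor([4])
#   log_probs, _ = lm.score_token(tokens, lens)  # (1, 500) next-token log-probs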

View File

@ -1 +0,0 @@
../../egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py

View File

@ -1,609 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./transformer_lm/train.py \
--start-epoch 0 \
--world-size 2 \
--num-epochs 1 \
--use-fp16 0 \
--num-layers 12 \
--batch-size 400
"""
import argparse
import logging
import math
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from dataset import get_dataloader
from lhotse.utils import fix_random_seed
from model import TransformerLM
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=30,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
exp_dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transformer_lm/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, logs, etc, are saved
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=True,
help="Whether to use half precision training.",
)
parser.add_argument(
"--batch-size",
type=int,
default=400,
)
parser.add_argument(
"--lm-data",
type=str,
default="data/lm_training_bpe_500/sorted_lm_data.pt",
help="LM training data",
)
parser.add_argument(
"--lm-data-valid",
type=str,
default="data/lm_training_bpe_500/sorted_lm_data-valid.pt",
help="LM validation data",
)
parser.add_argument(
"--vocab-size",
type=int,
default=500,
help="Vocabulary size of the model",
)
parser.add_argument(
"--num-layers",
type=int,
default=12,
help="Number of Transformer layers in the model",
)
parser.add_argument(
"--tie-weights",
type=str2bool,
default=True,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters."""
params = AttributeDict(
{
"max_sent_len": 200,
"sos_id": 1,
"eos_id": 1,
"blank_id": 0,
"lr": 1e-3,
"weight_decay": 1e-6,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 200,
"reset_interval": 2000,
"valid_interval": 1000,
"nhead": 8,
"embedding_dim": 768,
"encoder_dim": 768,
"dim_feedforward": 2048,
"dropout": 0.1,
"env_info": get_env_info(),
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
logging.info(f"Loading checkpoint: {filename}")
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
model: nn.Module,
x: torch.Tensor,
y: torch.Tensor,
sentence_lengths: torch.Tensor,
is_training: bool,
) -> Tuple[torch.Tensor, MetricsTracker]:
"""Compute the negative log-likelihood loss given a model and its input.
Args:
model:
The NN model,
x:
A 2-D tensor. Each row contains BPE token IDs for a sentence. Also,
each row starts with SOS ID.
y:
A 2-D tensor. Each row is a shifted version of the corresponding row
in `x` but ends with an EOS ID (before padding).
sentence_lengths:
        A 1-D tensor containing the number of tokens of each sentence
before padding.
is_training:
True for training. False for validation.
"""
with torch.set_grad_enabled(is_training):
device = model.device
x = x.to(device)
y = y.to(device)
sentence_lengths = sentence_lengths.to(device)
nll = model(x, y, sentence_lengths)
loss = nll.sum()
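        # `loss` is the NLL summed over all valid tokens in the batch; dividing by
        # the token count tracked below as "frames" gives the per-token NLL, whose
        # exponential is the perplexity reported during training.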
num_tokens = sentence_lengths.sum().item()
loss_info = MetricsTracker()
# Note: Due to how MetricsTracker() is designed,
# we use "frames" instead of "num_tokens" as a key here
loss_info["frames"] = num_tokens
loss_info["loss"] = loss.detach().item()
    return loss, loss_info


def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
x, y, sentence_lengths = batch
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
model=model,
x=x,
y=y,
sentence_lengths=sentence_lengths,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
    return tot_loss


def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all sentences is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
x, y, sentence_lengths = batch
batch_size = x.size(0)
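        # Run the forward pass under (optional) fp16 autocast, controlled by
        # --use-fp16; note that no gradient scaler is used in this training loop.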
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
model=model,
x=x,
y=y,
sentence_lengths=sentence_lengths,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
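        # `tot_loss` is an exponentially decayed running total: scaling by
        # (1 - 1/reset_interval) before adding the new batch keeps roughly the
        # last `reset_interval` batches in view.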
optimizer.zero_grad()
loss.backward()
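        # Clip gradients to a maximum L2 norm of 5.0 (second argument is the max
        # norm, third is the norm type) before the optimizer step.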
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
# Note: "frames" here means "num_tokens"
this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"])
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}, ppl: {this_batch_ppl}] "
f"tot_loss[{tot_loss}, ppl: {tot_ppl}], "
f"batch size: {batch_size}"
)
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tb_writer.add_scalar(
"train/current_ppl", this_batch_ppl, params.batch_idx_train
)
tb_writer.add_scalar("train/tot_ppl", tot_ppl, params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"])
logging.info(
f"Epoch {params.cur_epoch}, validation: {valid_info}, "
f"ppl: {valid_ppl}"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
tb_writer.add_scalar(
"train/valid_ppl", valid_ppl, params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss


def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoints.
world_size:
Number of GPUs for DDP training.
args:
        The return value of get_parser().parse_args().
"""
params = get_params()
params.update(vars(args))
is_distributed = world_size > 1
fix_random_seed(params.seed)
if is_distributed:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
logging.info("About to create model")
model = TransformerLM(
vocab_size=params.vocab_size,
d_model=params.encoder_dim,
embedding_dim=params.embedding_dim,
dim_feedforward=params.dim_feedforward,
nhead=params.nhead,
num_layers=params.num_layers,
tie_weights=params.tie_weights,
params=params,
)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if is_distributed:
model = DDP(model, device_ids=[rank])
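    # compute_loss() reads `model.device`, so attach the device to the model
    # (or to the DDP wrapper) explicitly.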
model.device = device
optimizer = optim.Adam(
model.parameters(),
lr=params.lr,
weight_decay=params.weight_decay,
)
if checkpoints:
logging.info("Load optimizer state_dict from checkpoint")
optimizer.load_state_dict(checkpoints["optimizer"])
logging.info(f"Loading LM training data from {params.lm_data}")
train_dl = get_dataloader(
filename=params.lm_data,
is_distributed=is_distributed,
params=params,
)
logging.info(f"Loading LM validation data from {params.lm_data_valid}")
valid_dl = get_dataloader(
filename=params.lm_data_valid,
is_distributed=is_distributed,
params=params,
)
# Note: No learning rate scheduler is used here
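    # (Adam therefore runs with the fixed learning rate params.lr for all epochs.)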
for epoch in range(params.start_epoch, params.num_epochs):
if is_distributed:
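            # DistributedSampler reshuffles based on the epoch number, so it must
            # be told which epoch this is.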
train_dl.sampler.set_epoch(epoch)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if is_distributed:
torch.distributed.barrier()
        cleanup_dist()


def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
        run(rank=0, world_size=1, args=args)


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
main()
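
# A hypothetical invocation (paths and flag values below are illustrative only;
# see get_parser() for the authoritative set of options and defaults):
#
#   ./transformer_lm/train.py \
#     --world-size 1 \
#     --exp-dir transformer_lm/exp \
#     --lm-data data/lm_training_bpe_500/sorted_lm_data.pt \
#     --lm-data-valid data/lm_training_bpe_500/sorted_lm_data-valid.pt \
#     --vocab-size 500 \
#     --num-layers 12 \
#     --batch-size 400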