Remove rnn_lm directory from egs

This commit is contained in:
Erwan 2022-06-22 10:47:50 +02:00
parent 8d2f4ba098
commit 73a31db1b2
8 changed files with 0 additions and 1504 deletions

View File

@ -1,228 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./rnn_lm/compute_perplexity.py \
--epoch 4 \
--avg 2 \
--lm-data ./data/bpe_500/sorted_lm_data-test.pt
"""
import argparse
import logging
import math
from pathlib import Path
import torch
from rnn_lm.dataset import get_dataloader
from rnn_lm.model import RnnLmModel
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import AttributeDict, setup_logger
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=49,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=20,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="rnn_lm/exp",
help="The experiment dir",
)
parser.add_argument(
"--lm-data",
type=str,
help="Path to the LM test data for computing perplexity",
)
parser.add_argument(
"--vocab-size",
type=int,
default=500,
help="Vocabulary size of the model",
)
parser.add_argument(
"--embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--batch-size",
type=int,
default=50,
help="Number of RNN layers the model",
)
parser.add_argument(
"--max-sent-len",
type=int,
default=100,
help="Number of RNN layers the model",
)
parser.add_argument(
"--sos-id",
type=int,
default=1,
help="SOS ID",
)
parser.add_argument(
"--eos-id",
type=int,
default=1,
help="EOS ID",
)
parser.add_argument(
"--blank-id",
type=int,
default=0,
help="Blank ID",
)
return parser
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lm_data = Path(args.lm_data)
params = AttributeDict(vars(args))
print(params)
setup_logger(f"{params.exp_dir}/log-ppl/")
logging.info("Computing perplexity started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
logging.info("About to create model")
model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
hidden_dim=params.hidden_dim,
num_layers=params.num_layers,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
num_param_requires_grad = sum(
[p.numel() for p in model.parameters() if p.requires_grad]
)
logging.info(f"Number of model parameters: {num_param}")
logging.info(
f"Number of model parameters (requires_grad): "
f"{num_param_requires_grad} "
f"({num_param_requires_grad/num_param_requires_grad*100}%)"
)
logging.info(f"Loading LM test data from {params.lm_data}")
test_dl = get_dataloader(
filename=params.lm_data,
is_distributed=False,
params=params,
)
tot_loss = 0.0
num_tokens = 0
num_sentences = 0
for batch_idx, batch in enumerate(test_dl):
x, y, sentence_lengths = batch
x = x.to(device)
y = y.to(device)
sentence_lengths = sentence_lengths.to(device)
nll = model(x, y, sentence_lengths)
loss = nll.sum().cpu().item()
tot_loss += loss
num_tokens += sentence_lengths.sum().cpu().item()
num_sentences += x.size(0)
ppl = math.exp(tot_loss / num_tokens)
logging.info(
f"total nll: {tot_loss}, num tokens: {num_tokens}, "
f"num sentences: {num_sentences}, ppl: {ppl:.3f}"
)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -1,316 +0,0 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
import k2
import torch
from icefall.utils import AttributeDict
class LmDataset(torch.utils.data.Dataset):
def __init__(
self,
sentences: k2.RaggedTensor,
words: k2.RaggedTensor,
sentence_lengths: torch.Tensor,
max_sent_len: int,
batch_size: int,
):
"""
Args:
sentences:
A ragged tensor of dtype torch.int32 with 2 axes [sentence][word].
words:
A ragged tensor of dtype torch.int32 with 2 axes [word][token].
sentence_lengths:
A 1-D tensor of dtype torch.int32 containing number of tokens
of each sentence.
max_sent_len:
Maximum sentence length. It is used to change the batch size
dynamically. In general, we try to keep the product of
"max_sent_len in a batch" and "num_of_sent in a batch" being
a constant.
batch_size:
The expected batch size. It is changed dynamically according
to the "max_sent_len".
See `../local/prepare_lm_training_data.py` for how `sentences` and
`words` are generated. We assume that `sentences` are sorted by length.
See `../local/sort_lm_training_data.py`.
"""
super().__init__()
self.sentences = sentences
self.words = words
sentence_lengths = sentence_lengths.tolist()
assert batch_size > 0, batch_size
assert max_sent_len > 1, max_sent_len
batch_indexes = []
num_sentences = sentences.dim0
cur = 0
while cur < num_sentences:
sz = sentence_lengths[cur] // max_sent_len + 1
# Assume the current sentence has 3 * max_sent_len tokens,
# in the worst case, the subsequent sentences also have
# this number of tokens, we should reduce the batch size
# so that this batch will not contain too many tokens
actucal_batch_size = batch_size // sz + 1
actucal_batch_size = min(actucal_batch_size, batch_size)
end = cur + actucal_batch_size
end = min(end, num_sentences)
this_batch_indexes = torch.arange(cur, end).tolist()
batch_indexes.append(this_batch_indexes)
cur = end
assert batch_indexes[-1][-1] == num_sentences - 1
self.batch_indexes = k2.RaggedTensor(batch_indexes)
def __len__(self) -> int:
"""Return number of batches in this dataset"""
return self.batch_indexes.dim0
def __getitem__(self, i: int) -> k2.RaggedTensor:
"""Get the i'th batch in this dataset
Return a ragged tensor with 2 axes [sentence][token].
"""
assert 0 <= i < len(self), i
# indexes is a 1-D tensor containing sentence indexes
indexes = self.batch_indexes[i]
# sentence_words is a ragged tensor with 2 axes
# [sentence][word]
sentence_words = self.sentences[indexes]
# in case indexes contains only 1 entry, the returned
# sentence_words is a 1-D tensor, we have to convert
# it to a ragged tensor
if isinstance(sentence_words, torch.Tensor):
sentence_words = k2.RaggedTensor(sentence_words.unsqueeze(0))
# sentence_word_tokens is a ragged tensor with 3 axes
# [sentence][word][token]
sentence_word_tokens = self.words.index(sentence_words)
assert sentence_word_tokens.num_axes == 3
sentence_tokens = sentence_word_tokens.remove_axis(1)
return sentence_tokens
def concat(
ragged: k2.RaggedTensor, value: int, direction: str
) -> k2.RaggedTensor:
"""Prepend a value to the beginning of each sublist or append a value.
to the end of each sublist.
Args:
ragged:
A ragged tensor with two axes.
value:
The value to prepend or append.
direction:
It can be either "left" or "right". If it is "left", we
prepend the value to the beginning of each sublist;
if it is "right", we append the value to the end of each
sublist.
Returns:
Return a new ragged tensor, whose sublists either start with
or end with the given value.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> concat(a, value=0, direction="left")
[ [ 0 1 3 ] [ 0 5 ] ]
>>> concat(a, value=0, direction="right")
[ [ 1 3 0 ] [ 5 0 ] ]
"""
dtype = ragged.dtype
device = ragged.device
assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
pad_values = torch.full(
size=(ragged.tot_size(0), 1),
fill_value=value,
device=device,
dtype=dtype,
)
pad = k2.RaggedTensor(pad_values)
if direction == "left":
ans = k2.ragged.cat([pad, ragged], axis=1)
elif direction == "right":
ans = k2.ragged.cat([ragged, pad], axis=1)
else:
raise ValueError(
f'Unsupported direction: {direction}. " \
"Expect either "left" or "right"'
)
return ans
def add_sos(ragged: k2.RaggedTensor, sos_id: int) -> k2.RaggedTensor:
"""Add SOS to each sublist.
Args:
ragged:
A ragged tensor with two axes.
sos_id:
The ID of the SOS symbol.
Returns:
Return a new ragged tensor, where each sublist starts with SOS.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> add_sos(a, sos_id=0)
[ [ 0 1 3 ] [ 0 5 ] ]
"""
return concat(ragged, sos_id, direction="left")
def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
"""Add EOS to each sublist.
Args:
ragged:
A ragged tensor with two axes.
eos_id:
The ID of the EOS symbol.
Returns:
Return a new ragged tensor, where each sublist ends with EOS.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> add_eos(a, eos_id=0)
[ [ 1 3 0 ] [ 5 0 ] ]
"""
return concat(ragged, eos_id, direction="right")
class LmDatasetCollate:
def __init__(self, sos_id: int, eos_id: int, blank_id: int):
"""
Args:
sos_id:
Token ID of the SOS symbol.
eos_id:
Token ID of the EOS symbol.
blank_id:
Token ID of the blank symbol.
"""
self.sos_id = sos_id
self.eos_id = eos_id
self.blank_id = blank_id
def __call__(
self, batch: List[k2.RaggedTensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Return a tuple containing 3 tensors:
- x, a 2-D tensor of dtype torch.int32; each row contains tokens
for a sentence starting with `self.sos_id`. It is padded to
the max sentence length with `self.blank_id`.
- y, a 2-D tensor of dtype torch.int32; each row contains tokens
for a sentence ending with `self.eos_id` before padding.
Then it is padded to the max sentence length with
`self.blank_id`.
- lengths, a 2-D tensor of dtype torch.int32, containing the number of
tokens of each sentence before padding.
"""
# The batching stuff has already been done in LmDataset
assert len(batch) == 1
sentence_tokens = batch[0]
row_splits = sentence_tokens.shape.row_splits(1)
sentence_token_lengths = row_splits[1:] - row_splits[:-1]
sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id)
sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id)
x = sentence_tokens_with_sos.pad(
mode="constant", padding_value=self.blank_id
)
y = sentence_tokens_with_eos.pad(
mode="constant", padding_value=self.blank_id
)
sentence_token_lengths += 1 # plus 1 since we added a SOS
return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths
def get_dataloader(
filename: str,
is_distributed: bool,
params: AttributeDict,
) -> torch.utils.data.DataLoader:
"""Get dataloader for LM training.
Args:
filename:
Path to the file containing LM data. The file is assumed to
be generated by `../local/sort_lm_training_data.py`.
is_distributed:
True if using DDP training. False otherwise.
params:
Set `get_params()` from `rnn_lm/train.py`
Returns:
Return a dataloader containing the LM data.
"""
lm_data = torch.load(filename)
words = lm_data["words"]
sentences = lm_data["sentences"]
sentence_lengths = lm_data["sentence_lengths"]
dataset = LmDataset(
sentences=sentences,
words=words,
sentence_lengths=sentence_lengths,
max_sent_len=params.max_sent_len,
batch_size=params.batch_size,
)
if is_distributed:
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, shuffle=True, drop_last=False
)
else:
sampler = None
collate_fn = LmDatasetCollate(
sos_id=params.sos_id,
eos_id=params.eos_id,
blank_id=params.blank_id,
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=1,
collate_fn=collate_fn,
sampler=sampler,
shuffle=sampler is None,
)
return dataloader

View File

@ -1,145 +0,0 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import torch
import torch.nn.functional as F
def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
"""
Args:
lengths:
A 1-D tensor containing sentence lengths.
Returns:
Return a 2-D bool tensor, where masked positions
are filled with `True` and non-masked positions are
filled with `False`.
>>> lengths = torch.tensor([1, 3, 2, 5])
>>> make_pad_mask(lengths)
tensor([[False, True, True, True, True],
[False, False, False, True, True],
[False, False, True, True, True],
[False, False, False, False, False]])
"""
assert lengths.ndim == 1, lengths.ndim
max_len = lengths.max()
n = lengths.size(0)
expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)
return expaned_lengths >= lengths.unsqueeze(1)
class RnnLmModel(torch.nn.Module):
def __init__(
self,
vocab_size: int,
embedding_dim: int,
hidden_dim: int,
num_layers: int,
tie_weights: bool = False,
):
"""
Args:
vocab_size:
Vocabulary size of BPE model.
embedding_dim:
Input embedding dimension.
hidden_dim:
Hidden dimension of RNN layers.
num_layers:
Number of RNN layers.
tie_weights:
True to share the weights between the input embedding layer and the
last output linear layer. See https://arxiv.org/abs/1608.05859
and https://arxiv.org/abs/1611.01462
"""
super().__init__()
self.input_embedding = torch.nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
)
self.rnn = torch.nn.LSTM(
input_size=embedding_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
)
self.output_linear = torch.nn.Linear(
in_features=hidden_dim, out_features=vocab_size
)
self.vocab_size = vocab_size
if tie_weights:
logging.info("Tying weights")
assert embedding_dim == hidden_dim, (embedding_dim, hidden_dim)
self.output_linear.weight = self.input_embedding.weight
else:
logging.info("Not tying weights")
def forward(
self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor
) -> torch.Tensor:
"""
Args:
x:
A 2-D tensor with shape (N, L). Each row
contains token IDs for a sentence and starts with the SOS token.
y:
A shifted version of `x` and with EOS appended.
lengths:
A 1-D tensor of shape (N,). It contains the sentence lengths
before padding.
Returns:
Return a 2-D tensor of shape (N, L) containing negative log-likelihood
loss values. Note: Loss values for padding positions are set to 0.
"""
assert x.ndim == y.ndim == 2, (x.ndim, y.ndim)
assert lengths.ndim == 1, lengths.ndim
assert x.shape == y.shape, (x.shape, y.shape)
batch_size = x.size(0)
assert lengths.size(0) == batch_size, (lengths.size(0), batch_size)
# embedding is of shape (N, L, embedding_dim)
embedding = self.input_embedding(x)
# Note: We use batch_first==True
rnn_out, _ = self.rnn(embedding)
logits = self.output_linear(rnn_out)
# Note: No need to use `log_softmax()` here
# since F.cross_entropy() expects unnormalized probabilities
# nll_loss is of shape (N*L,)
# nll -> negative log-likelihood
nll_loss = F.cross_entropy(
logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none"
)
# Set loss values for padding positions to 0
mask = make_pad_mask(lengths).reshape(-1)
nll_loss.masked_fill_(mask, 0)
nll_loss = nll_loss.reshape(batch_size, -1)
return nll_loss

View File

@ -1,74 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import k2
import torch
from rnn_lm.dataset import LmDataset, LmDatasetCollate
def main():
sentences = k2.RaggedTensor(
[[0, 1, 2], [1, 0, 1], [0, 1], [1, 3, 0, 2, 0], [3], [0, 2, 1]]
)
words = k2.RaggedTensor([[3, 6], [2, 8, 9, 3], [5], [5, 6, 7, 8, 9]])
num_sentences = sentences.dim0
sentence_lengths = [0] * num_sentences
for i in range(num_sentences):
word_ids = sentences[i]
# NOTE: If word_ids is a tensor with only 1 entry,
# token_ids is a torch.Tensor
token_ids = words[word_ids]
if isinstance(token_ids, k2.RaggedTensor):
token_ids = token_ids.values
# token_ids is a 1-D tensor containing the BPE tokens
# of the current sentence
sentence_lengths[i] = token_ids.numel()
sentence_lengths = torch.tensor(sentence_lengths, dtype=torch.int32)
indices = torch.argsort(sentence_lengths, descending=True)
sentences = sentences[indices.to(torch.int32)]
sentence_lengths = sentence_lengths[indices]
dataset = LmDataset(
sentences=sentences,
words=words,
sentence_lengths=sentence_lengths,
max_sent_len=3,
batch_size=4,
)
print(dataset.sentences)
print(dataset.words)
print(dataset.batch_indexes)
print(len(dataset))
collate_fn = LmDatasetCollate(sos_id=1, eos_id=-1, blank_id=0)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=1, collate_fn=collate_fn
)
for i in dataloader:
print(i)
# I've checked the output manually; the output is as expected.
if __name__ == "__main__":
main()

View File

@ -1,103 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import k2
import torch
import torch.multiprocessing as mp
from rnn_lm.dataset import LmDataset, LmDatasetCollate
from torch import distributed as dist
def generate_data():
sentences = k2.RaggedTensor(
[[0, 1, 2], [1, 0, 1], [0, 1], [1, 3, 0, 2, 0], [3], [0, 2, 1]]
)
words = k2.RaggedTensor([[3, 6], [2, 8, 9, 3], [5], [5, 6, 7, 8, 9]])
num_sentences = sentences.dim0
sentence_lengths = [0] * num_sentences
for i in range(num_sentences):
word_ids = sentences[i]
# NOTE: If word_ids is a tensor with only 1 entry,
# token_ids is a torch.Tensor
token_ids = words[word_ids]
if isinstance(token_ids, k2.RaggedTensor):
token_ids = token_ids.values
# token_ids is a 1-D tensor containing the BPE tokens
# of the current sentence
sentence_lengths[i] = token_ids.numel()
sentence_lengths = torch.tensor(sentence_lengths, dtype=torch.int32)
indices = torch.argsort(sentence_lengths, descending=True)
sentences = sentences[indices.to(torch.int32)]
sentence_lengths = sentence_lengths[indices]
return sentences, words, sentence_lengths
def run(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12352"
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
sentences, words, sentence_lengths = generate_data()
dataset = LmDataset(
sentences=sentences,
words=words,
sentence_lengths=sentence_lengths,
max_sent_len=3,
batch_size=4,
)
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, shuffle=True, drop_last=False
)
collate_fn = LmDatasetCollate(sos_id=1, eos_id=-1, blank_id=0)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=1,
collate_fn=collate_fn,
sampler=sampler,
shuffle=False,
)
for i in dataloader:
print(f"rank: {rank}", i)
dist.destroy_process_group()
def main():
world_size = 2
mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -1,84 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from rnn_lm.model import RnnLmModel, make_pad_mask
def test_makd_pad_mask():
lengths = torch.tensor([1, 3, 2])
mask = make_pad_mask(lengths)
expected = torch.tensor(
[
[False, True, True],
[False, False, False],
[False, False, True],
]
)
assert torch.all(torch.eq(mask, expected))
assert (~expected).sum() == lengths.sum()
def test_rnn_lm_model():
vocab_size = 4
model = RnnLmModel(
vocab_size=vocab_size, embedding_dim=10, hidden_dim=10, num_layers=2
)
x = torch.tensor(
[
[1, 3, 2, 2],
[1, 2, 2, 0],
[1, 2, 0, 0],
]
)
y = torch.tensor(
[
[3, 2, 2, 1],
[2, 2, 1, 0],
[2, 1, 0, 0],
]
)
lengths = torch.tensor([4, 3, 2])
nll_loss = model(x, y, lengths)
print(nll_loss)
"""
tensor([[1.1180, 1.3059, 1.2426, 1.7773],
[1.4231, 1.2783, 1.7321, 0.0000],
[1.4231, 1.6752, 0.0000, 0.0000]], grad_fn=<ViewBackward>)
"""
def test_rnn_lm_model_tie_weights():
model = RnnLmModel(
vocab_size=10,
embedding_dim=10,
hidden_dim=10,
num_layers=2,
tie_weights=True,
)
assert model.input_embedding.weight is model.output_linear.weight
def main():
test_makd_pad_mask()
test_rnn_lm_model()
test_rnn_lm_model_tie_weights()
if __name__ == "__main__":
torch.manual_seed(20211122)
main()

View File

@ -1,554 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./rnn_lm/train.py \
--world-size 2 \
--start-epoch 4
"""
import argparse
import logging
import math
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from lhotse.utils import fix_random_seed
from rnn_lm.dataset import get_dataloader
from rnn_lm.model import RnnLmModel
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.utils import (
AttributeDict,
MetricsTracker,
get_env_info,
setup_logger,
str2bool,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=10,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
exp_dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="rnn_lm/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, logs, etc, are saved
""",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters."""
params = AttributeDict(
{
# LM training/validation data
"lm_data": "data/bpe_500/sorted_lm_data.pt",
"lm_data_valid": "data/bpe_500/sorted_lm_data-valid.pt",
"batch_size": 50,
"max_sent_len": 200,
"sos_id": 1,
"eos_id": 1,
"blank_id": 0,
# model related
#
# vocab size of the BPE model
"vocab_size": 500,
"embedding_dim": 2048,
"hidden_dim": 2048,
"num_layers": 4,
#
"lr": 1e-3,
"weight_decay": 1e-6,
#
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 300,
"env_info": get_env_info(),
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
logging.info(f"Loading checkpoint: {filename}")
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
model: nn.Module,
x: torch.Tensor,
y: torch.Tensor,
sentence_lengths: torch.Tensor,
is_training: bool,
) -> Tuple[torch.Tensor, MetricsTracker]:
"""Compute the negative log-likelihood loss given a model and its input.
Args:
model:
The NN model, e.g., RnnLmModel.
x:
A 2-D tensor. Each row contains BPE token IDs for a sentence. Also,
each row starts with SOS ID.
y:
A 2-D tensor. Each row is a shifted version of the corresponding row
in `x` but ends with an EOS ID (before padding).
sentence_lengths:
A 1-D tensor containing number of tokens of each sentence
before padding.
is_training:
True for training. False for validation.
"""
with torch.set_grad_enabled(is_training):
device = model.device
x = x.to(device)
y = y.to(device)
sentence_lengths = sentence_lengths.to(device)
nll = model(x, y, sentence_lengths)
loss = nll.sum()
num_tokens = sentence_lengths.sum().item()
loss_info = MetricsTracker()
# Note: Due to how MetricsTracker() is designed,
# we use "frames" instead of "num_tokens" as a key here
loss_info["frames"] = num_tokens
loss_info["loss"] = loss.detach().item()
return loss, loss_info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
x, y, sentence_lengths = batch
loss, loss_info = compute_loss(
model=model,
x=x,
y=y,
sentence_lengths=sentence_lengths,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all sentences is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
x, y, sentence_lengths = batch
batch_size = x.size(0)
loss, loss_info = compute_loss(
model=model,
x=x,
y=y,
sentence_lengths=sentence_lengths,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
# Note: "frames" here means "num_tokens"
this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"])
tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"])
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}, ppl: {this_batch_ppl}] "
f"tot_loss[{tot_loss}, ppl: {tot_ppl}], "
f"batch size: {batch_size}"
)
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tb_writer.add_scalar(
"train/current_ppl", this_batch_ppl, params.batch_idx_train
)
tb_writer.add_scalar(
"train/tot_ppl", tot_ppl, params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"])
logging.info(
f"Epoch {params.cur_epoch}, validation: {valid_info}, "
f"ppl: {valid_ppl}"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
tb_writer.add_scalar(
"train/valid_ppl", valid_ppl, params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(42)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
logging.info("About to create model")
model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
hidden_dim=params.hidden_dim,
num_layers=params.num_layers,
)
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = optim.Adam(
model.parameters(),
lr=params.lr,
weight_decay=params.weight_decay,
)
if checkpoints:
logging.info("Load optimizer state_dict from checkpoint")
optimizer.load_state_dict(checkpoints["optimizer"])
logging.info(f"Loading LM training data from {params.lm_data}")
train_dl = get_dataloader(
filename=params.lm_data,
is_distributed=world_size > 1,
params=params,
)
logging.info(f"Loading LM validation data from {params.lm_data_valid}")
valid_dl = get_dataloader(
filename=params.lm_data_valid,
is_distributed=world_size > 1,
params=params,
)
# Note: No learning rate scheduler is used here
for epoch in range(params.start_epoch, params.num_epochs):
if world_size > 1:
train_dl.sampler.set_epoch(epoch)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()