RNN-T training for yesno. (#141)

* RNN-T training for yesno.

* Rename Jointer to Joiner.
This commit is contained in:
Fangjun Kuang 2021-12-07 21:44:37 +08:00 committed by GitHub
parent 1aff64b708
commit 95af039733
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 1722 additions and 1 deletions

2
.gitignore vendored
View File

@ -6,3 +6,5 @@ exp
exp*/
*.pt
download
*.bak
*-bak

View File

@ -487,6 +487,7 @@ def run(rank, world_size, args):
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"device: {device}")
graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)

View File

View File

@ -0,0 +1 @@
../tdnn/asr_datamodule.py

View File

@ -0,0 +1,69 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import torch
from transducer.model import Transducer
def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
"""
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
decoder_out, (h, c) = model.decoder(sos)
T = encoder_out.size(1)
t = 0
hyp = []
max_u = 1000 # terminte after this number of steps
u = 0
while t < T and u < max_u:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on
logits = model.joiner(current_encoder_out, decoder_out)
log_prob = logits.log_softmax(dim=-1)
# log_prob is (N, 1, 1)
# TODO: Use logits.argmax()
y = log_prob.argmax()
if y != blank_id:
hyp.append(y.item())
y = y.reshape(1, 1)
decoder_out, (h, c) = model.decoder(y, (h, c))
u += 1
else:
t += 1
id2word = {1: "YES", 2: "NO"}
hyp = [id2word[i] for i in hyp]
return hyp

View File

@ -0,0 +1,310 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import torch
import torch.nn as nn
from asr_datamodule import YesNoAsrDataModule
from transducer.beam_search import greedy_search
from transducer.decoder import Decoder
from transducer.encoder import Tdnn
from transducer.joiner import Joiner
from transducer.model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=125,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=20,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer/exp",
help="Directory from which to load the checkpoints",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 23,
# encoder/decoder params
"vocab_size": 3, # blank, yes, no
"blank_id": 0,
"embedding_dim": 32,
"hidden_dim": 16,
"num_decoder_layers": 4,
}
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
batch: dict,
) -> List[List[int]]:
"""Decode one batch and return the result in a list-of-list.
Each sub list contains the word IDs for an utterance in the batch.
Args:
params:
It's the return value of :func:`get_params`.
- params.method is "1best", it uses 1best decoding.
- params.method is "nbest", it uses nbest decoding.
model:
The neural model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
(https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py)
Returns:
Return the decoding result. `len(ans)` == batch size.
"""
device = model.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
feature_lens = batch["supervisions"]["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
hyp = greedy_search(model=model, encoder_out=encoder_out_i)
hyps.append(hyp)
# hyps = [[word_table[i] for i in ids] for ids in hyps]
return hyps
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
) -> List[Tuple[List[int], List[int]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
Returns:
Return a tuple contains two elements (ref_text, hyp_text):
The first is the reference transcript, and the second is the
predicted result.
"""
results = []
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = []
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps = decode_one_batch(
params=params,
model=model,
batch=batch,
)
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results.extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
exp_dir: Path,
test_set_name: str,
results: List[Tuple[List[int], List[int]]],
) -> None:
"""Save results to `exp_dir`.
Args:
exp_dir:
The output directory. This function create the following files inside
this directory:
- recogs-{test_set_name}.text
It contains the reference and hypothesis results, like below::
ref=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
hyp=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
ref=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
hyp=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
- errs-{test_set_name}.txt
It contains the detailed WER.
test_set_name:
The name of the test set, which will be part of the result filename.
results:
A list of tuples, each of which contains (ref_words, hyp_words).
Returns:
Return None.
"""
recog_path = exp_dir / f"recogs-{test_set_name}.txt"
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = exp_dir / f"errs-{test_set_name}.txt"
with open(errs_filename, "w") as f:
write_error_stats(f, f"{test_set_name}", results)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
def get_transducer_model(params: AttributeDict):
encoder = Tdnn(
num_features=params.feature_dim,
output_dim=params.hidden_dim,
)
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
blank_id=params.blank_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.hidden_dim,
embedding_dropout=0.4,
rnn_dropout=0.4,
)
joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
return transducer
@torch.no_grad()
def main():
parser = get_parser()
YesNoAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params["env_info"] = get_env_info()
setup_logger(f"{params.exp_dir}/log/log-decode")
logging.info("Decoding started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = get_transducer_model(params)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
model.to(device)
model.eval()
model.device = device
yes_no = YesNoAsrDataModule(args)
test_dl = yes_no.test_dataloaders()
results = decode_dataset(
dl=test_dl,
params=params,
model=model,
)
save_results(
exp_dir=params.exp_dir, test_set_name="test_set", results=results
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,92 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
import torch.nn as nn
class Decoder(nn.Module):
def __init__(
self,
vocab_size: int,
embedding_dim: int,
blank_id: int,
num_layers: int,
hidden_dim: int,
embedding_dropout: float = 0.0,
rnn_dropout: float = 0.0,
):
"""
Args:
vocab_size:
Number of tokens of the modeling unit.
embedding_dim:
Dimension of the input embedding.
blank_id:
The ID of the blank symbol.
num_layers:
Number of RNN layers.
hidden_dim:
Hidden dimension of RNN layers.
embedding_dropout:
Dropout rate for the embedding layer.
rnn_dropout:
Dropout for RNN layers.
"""
super().__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
padding_idx=blank_id,
)
self.embedding_dropout = nn.Dropout(embedding_dropout)
self.rnn = nn.LSTM(
input_size=embedding_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
dropout=rnn_dropout,
)
self.blank_id = blank_id
self.output_linear = nn.Linear(hidden_dim, hidden_dim)
def forward(
self,
y: torch.Tensor,
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Args:
y:
A 2-D tensor of shape (N, U).
states:
A tuple of two tensors containing the states information of
RNN layers in this decoder.
Returns:
Return a tuple containing:
- rnn_output, a tensor of shape (N, U, C)
- (h, c), which contain the state information for RNN layers.
Both are of shape (num_layers, N, C)
"""
embeding_out = self.embedding(y)
embeding_out = self.embedding_dropout(embeding_out)
rnn_out, (h, c) = self.rnn(embeding_out, states)
out = self.output_linear(rnn_out)
return out, (h, c)

View File

@ -0,0 +1,87 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
# We use a TDNN model as encoder, as it works very well with CTC training
# for this tiny dataset.
class Tdnn(nn.Module):
def __init__(self, num_features: int, output_dim: int):
"""
Args:
num_features:
Model input dimension.
ouput_dim:
Model output dimension
"""
super().__init__()
# Note: We don't use paddings inside conv layers
self.tdnn = nn.Sequential(
nn.Conv1d(
in_channels=num_features,
out_channels=32,
kernel_size=3,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=32, affine=False),
nn.Conv1d(
in_channels=32,
out_channels=32,
kernel_size=5,
dilation=2,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=32, affine=False),
nn.Conv1d(
in_channels=32,
out_channels=32,
kernel_size=5,
dilation=4,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=32, affine=False),
)
self.output_linear = nn.Linear(in_features=32, out_features=output_dim)
def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
The input tensor with shape (N, T, C)
x_lens:
It contains the number of frames in each utterance in x
before padding.
Returns:
Return a tuple with 2 tensors:
- logits, a tensor of shape (N, T, C)
- logit_lens, a tensor of shape (N,)
"""
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.tdnn(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
logits = self.output_linear(x)
# the first conv layer reduces T by 3-1 frames
# the second layer reduces T by (5-1)*2 frames
# the second layer reduces T by (5-1)*4 frames
# Number of output frames is 2 + 4*2 + 4*4 = 2 + 8 + 16 = 26
x_lens = x_lens - 26
return logits, x_lens

View File

@ -0,0 +1,55 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Joiner(nn.Module):
def __init__(self, input_dim: int, output_dim: int):
super().__init__()
self.output_linear = nn.Linear(input_dim, output_dim)
def forward(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, C).
decoder_out:
Output from the decoder. Its shape is (N, U, C).
Returns:
Return a tensor of shape (N, T, U, C).
"""
assert encoder_out.ndim == decoder_out.ndim == 3
assert encoder_out.size(0) == decoder_out.size(0)
assert encoder_out.size(2) == decoder_out.size(2)
encoder_out = encoder_out.unsqueeze(2)
# Now encoder_out is (N, T, 1, C)
decoder_out = decoder_out.unsqueeze(1)
# Now decoder_out is (N, 1, U, C)
logit = encoder_out + decoder_out
logit = F.relu(logit)
output = self.output_linear(logit)
return output

View File

@ -0,0 +1,120 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note we use `rnnt_loss` from torchaudio, which exists only in
torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
"""
import k2
import torch
import torch.nn as nn
import torchaudio
import torchaudio.functional
from icefall.utils import add_sos
assert hasattr(torchaudio.functional, "rnnt_loss"), (
f"Current torchaudio version: {torchaudio.__version__}\n"
"Please install a version >= 0.10.0"
)
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder: nn.Module,
decoder: nn.Module,
joiner: nn.Module,
):
"""
Args:
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, C) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain
one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
Returns:
Return the transducer loss.
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
decoder_out, _ = self.decoder(sos_y_padded)
logits = self.joiner(encoder_out, decoder_out)
# rnnt_loss requires 0 padded targets
y_padded = y.pad(mode="constant", padding_value=0)
loss = torchaudio.functional.rnnt_loss(
logits=logits,
targets=y_padded,
logit_lengths=x_lens,
target_lengths=y_lens,
blank=blank_id,
reduction="mean",
)
return loss

View File

@ -0,0 +1,65 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/yesno/ASR
python ./transducer/test_decoder.py
"""
import torch
from transducer.decoder import Decoder
def test_decoder():
vocab_size = 3
blank_id = 0
embedding_dim = 128
num_layers = 2
hidden_dim = 6
N = 3
U = 5
decoder = Decoder(
vocab_size=vocab_size,
embedding_dim=embedding_dim,
blank_id=blank_id,
num_layers=num_layers,
hidden_dim=hidden_dim,
embedding_dropout=0.0,
rnn_dropout=0.0,
)
x = torch.randint(1, vocab_size, (N, U))
rnn_out, (h, c) = decoder(x)
assert rnn_out.shape == (N, U, hidden_dim)
assert h.shape == (num_layers, N, hidden_dim)
assert c.shape == (num_layers, N, hidden_dim)
rnn_out, (h, c) = decoder(x, (h, c))
assert rnn_out.shape == (N, U, hidden_dim)
assert h.shape == (num_layers, N, hidden_dim)
assert c.shape == (num_layers, N, hidden_dim)
def main():
test_decoder()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,47 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/yesno/ASR
python ./transducer/test_encoder.py
"""
import torch
from transducer.encoder import Tdnn
def test_encoder():
input_dim = 10
output_dim = 20
encoder = Tdnn(input_dim, output_dim)
N = 10
T = 85
x = torch.rand(N, T, input_dim)
x_lens = torch.randint(low=30, high=T, size=(N,), dtype=torch.int32)
logits, logit_lens = encoder(x, x_lens)
assert logits.shape == (N, T - 26, output_dim)
assert torch.all(torch.eq(x_lens - 26, logit_lens))
def main():
test_encoder()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/yesno/ASR
python ./transducer/test_joiner.py
"""
import torch
from transducer.joiner import Joiner
def test_joiner():
N = 2
T = 3
C = 4
U = 5
joiner = Joiner(C, 10)
encoder_out = torch.rand(N, T, C)
decoder_out = torch.rand(N, U, C)
joint = joiner(encoder_out, decoder_out)
assert joint.shape == (N, T, U, 10)
def main():
test_joiner()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,77 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/yesno/ASR
python ./transducer/test_transducer.py
"""
import k2
import torch
from transducer.decoder import Decoder
from transducer.encoder import Tdnn
from transducer.joiner import Joiner
from transducer.model import Transducer
def test_transducer():
# encoder params
input_dim = 10
output_dim = 20
# decoder params
vocab_size = 3
blank_id = 0
embedding_dim = 128
num_layers = 2
encoder = Tdnn(input_dim, output_dim)
decoder = Decoder(
vocab_size=vocab_size,
embedding_dim=embedding_dim,
blank_id=blank_id,
num_layers=num_layers,
hidden_dim=output_dim,
embedding_dropout=0.0,
rnn_dropout=0.0,
)
joiner = Joiner(output_dim, vocab_size)
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
y = k2.RaggedTensor([[1, 2, 1], [1, 1, 1, 2, 1]])
N = y.dim0
T = 50
x = torch.rand(N, T, input_dim)
x_lens = torch.randint(low=30, high=T, size=(N,), dtype=torch.int32)
x_lens[0] = T
loss = transducer(x, x_lens, y)
print(loss)
def main():
test_transducer()
if __name__ == "__main__":
main()

581
egs/yesno/ASR/transducer/train.py Executable file
View File

@ -0,0 +1,581 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import List, Optional, Tuple
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from asr_datamodule import YesNoAsrDataModule
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transducer.decoder import Decoder
from transducer.encoder import Tdnn
from transducer.joiner import Joiner
from transducer.model import Transducer
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
def get_labels(texts: List[str]) -> k2.RaggedTensor:
"""
Args:
texts:
A list of transcripts. Each transcript contains spaces separated
"NO" or "YES".
Returns:
Return a ragged tensor containing the corresponding word ID.
"""
# blank is 0
word2id = {"YES": 1, "NO": 2}
word_ids = []
for t in texts:
words = t.split()
ids = [word2id[w] for w in words]
word_ids.append(ids)
return k2.RaggedTensor(word_ids)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=200,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
tdnn/exp/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer/exp",
help="Directory to save results",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
is saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- lr: It specifies the initial learning rate
- feature_dim: The model input dim. It has to match the one used
in computing features.
- weight_decay: The weight_decay for the optimizer.
- subsampling_factor: The subsampling factor for the model.
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
and continue training from that checkpoint.
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- valid_interval: Run validation if batch_idx % valid_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
"""
params = AttributeDict(
{
"lr": 1e-3,
"feature_dim": 23,
"weight_decay": 1e-6,
"start_epoch": 0,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 10,
"reset_interval": 20,
"valid_interval": 10,
# encoder/decoder params
"vocab_size": 3, # blank, yes, no
"blank_id": 0,
"embedding_dim": 32,
"hidden_dim": 16,
"num_decoder_layers": 4,
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler._LRScheduler,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute RNN-T loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Tdnn in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = model.device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
feature_lens = batch["supervisions"]["num_frames"].to(device)
texts = batch["supervisions"]["text"]
labels = get_labels(texts).to(device)
with torch.set_grad_enabled(is_training):
loss = model(x=feature, x_lens=feature_lens, y=labels)
assert loss.requires_grad == is_training
info = MetricsTracker()
info["frames"] = feature.size(0)
info["loss"] = loss.detach().cpu().item()
return loss, info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
is_training=True,
)
# summary stats.
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
valid_info = compute_validation_loss(
params=params,
model=model,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
if tb_writer is not None:
valid_info.write_summary(
tb_writer,
"train/valid_",
params.batch_idx_train,
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def get_transducer_model(params: AttributeDict):
encoder = Tdnn(
num_features=params.feature_dim,
output_dim=params.hidden_dim,
)
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
blank_id=params.blank_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.hidden_dim,
embedding_dropout=0.4,
rnn_dropout=0.4,
)
joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
return transducer
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
params["env_info"] = get_env_info()
fix_random_seed(42)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"device: {device}")
model = get_transducer_model(params)
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = optim.Adam(
model.parameters(),
lr=params.lr,
weight_decay=params.weight_decay,
)
if checkpoints:
optimizer.load_state_dict(checkpoints["optimizer"])
yes_no = YesNoAsrDataModule(args)
train_dl = yes_no.train_dataloaders()
# There are only 60 waves: 30 files are used for training
# and the remaining 30 files are used for testing.
# We use test data as validation.
valid_dl = yes_no.test_dataloaders()
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
scheduler=None,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
YesNoAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
main()

View File

@ -565,3 +565,128 @@ class MetricsTracker(collections.defaultdict):
"""
for k, v in self.norm_items():
tb_writer.add_scalar(prefix + k, v, batch_idx)
def concat(
ragged: k2.RaggedTensor, value: int, direction: str
) -> k2.RaggedTensor:
"""Prepend a value to the beginning of each sublist or append a value.
to the end of each sublist.
Args:
ragged:
A ragged tensor with two axes.
value:
The value to prepend or append.
direction:
It can be either "left" or "right". If it is "left", we
prepend the value to the beginning of each sublist;
if it is "right", we append the value to the end of each
sublist.
Returns:
Return a new ragged tensor, whose sublists either start with
or end with the given value.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> concat(a, value=0, direction="left")
[ [ 0 1 3 ] [ 0 5 ] ]
>>> concat(a, value=0, direction="right")
[ [ 1 3 0 ] [ 5 0 ] ]
"""
dtype = ragged.dtype
device = ragged.device
assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
pad_values = torch.full(
size=(ragged.tot_size(0), 1),
fill_value=value,
device=device,
dtype=dtype,
)
pad = k2.RaggedTensor(pad_values)
if direction == "left":
ans = k2.ragged.cat([pad, ragged], axis=1)
elif direction == "right":
ans = k2.ragged.cat([ragged, pad], axis=1)
else:
raise ValueError(
f'Unsupported direction: {direction}. " \
"Expect either "left" or "right"'
)
return ans
def add_sos(ragged: k2.RaggedTensor, sos_id: int) -> k2.RaggedTensor:
"""Add SOS to each sublist.
Args:
ragged:
A ragged tensor with two axes.
sos_id:
The ID of the SOS symbol.
Returns:
Return a new ragged tensor, where each sublist starts with SOS.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> add_sos(a, sos_id=0)
[ [ 0 1 3 ] [ 0 5 ] ]
"""
return concat(ragged, sos_id, direction="left")
def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
"""Add EOS to each sublist.
Args:
ragged:
A ragged tensor with two axes.
eos_id:
The ID of the EOS symbol.
Returns:
Return a new ragged tensor, where each sublist ends with EOS.
>>> a = k2.RaggedTensor([[1, 3], [5]])
>>> a
[ [ 1 3 ] [ 5 ] ]
>>> add_eos(a, eos_id=0)
[ [ 1 3 0 ] [ 5 0 ] ]
"""
return concat(ragged, eos_id, direction="right")
def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
"""
Args:
lengths:
A 1-D tensor containing sentence lengths.
Returns:
Return a 2-D bool tensor, where masked positions
are filled with `True` and non-masked positions are
filled with `False`.
>>> lengths = torch.tensor([1, 3, 2, 5])
>>> make_pad_mask(lengths)
tensor([[False, True, True, True, True],
[False, False, False, True, True],
[False, False, True, True, True],
[False, False, False, False, False]])
"""
assert lengths.ndim == 1, lengths.ndim
max_len = lengths.max()
n = lengths.size(0)
expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)
return expaned_lengths >= lengths.unsqueeze(1)

View File

@ -21,7 +21,14 @@ import pytest
import torch
from icefall.env import get_env_info
from icefall.utils import AttributeDict, encode_supervisions, get_texts
from icefall.utils import (
AttributeDict,
add_eos,
add_sos,
encode_supervisions,
get_texts,
make_pad_mask,
)
@pytest.fixture
@ -126,3 +133,35 @@ def test_attribute_dict():
def test_get_env_info():
s = get_env_info()
print(s)
def test_makd_pad_mask():
lengths = torch.tensor([1, 3, 2])
mask = make_pad_mask(lengths)
expected = torch.tensor(
[
[False, True, True],
[False, False, False],
[False, False, True],
]
)
assert torch.all(torch.eq(mask, expected))
assert (~expected).sum() == lengths.sum()
def test_add_sos():
sos_id = 100
ragged = k2.RaggedTensor([[1, 2], [3], [0]])
sos_ragged = add_sos(ragged, sos_id)
expected = k2.RaggedTensor([[sos_id, 1, 2], [sos_id, 3], [sos_id, 0]])
assert str(sos_ragged) == str(expected)
def test_add_eos():
eos_id = 30
ragged = k2.RaggedTensor([[1, 2], [3], [], [5, 8, 9]])
ragged_eos = add_eos(ragged, eos_id)
expected = k2.RaggedTensor(
[[1, 2, eos_id], [3, eos_id], [eos_id], [5, 8, 9, eos_id]]
)
assert str(ragged_eos) == str(expected)