mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Begin to add RNN-T training for librispeech.
This commit is contained in:
parent
95af039733
commit
5802d5ad2e
98
egs/librispeech/ASR/transducer/decoder.py
Normal file
98
egs/librispeech/ASR/transducer/decoder.py
Normal file
@ -0,0 +1,98 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# TODO(fangjun): Support switching between LSTM and GRU
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
embedding_dim: int,
|
||||
blank_id: int,
|
||||
sos_id: int,
|
||||
num_layers: int,
|
||||
hidden_dim: int,
|
||||
embedding_dropout: float = 0.0,
|
||||
rnn_dropout: float = 0.0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
vocab_size:
|
||||
Number of tokens of the modeling unit including blank.
|
||||
embedding_dim:
|
||||
Dimension of the input embedding.
|
||||
blank_id:
|
||||
The ID of the blank symbol.
|
||||
sos_id:
|
||||
The ID of the SOS symbol.
|
||||
num_layers:
|
||||
Number of LSTM layers.
|
||||
hidden_dim:
|
||||
Hidden dimension of LSTM layers.
|
||||
embedding_dropout:
|
||||
Dropout rate for the embedding layer.
|
||||
rnn_dropout:
|
||||
Dropout for LSTM layers.
|
||||
"""
|
||||
super().__init__()
|
||||
self.embedding = nn.Embedding(
|
||||
num_embeddings=vocab_size,
|
||||
embedding_dim=embedding_dim,
|
||||
padding_idx=blank_id,
|
||||
)
|
||||
self.embedding_dropout = nn.Dropout(embedding_dropout)
|
||||
# TODO(fangjun): Use layer normalized LSTM
|
||||
self.rnn = nn.LSTM(
|
||||
input_size=embedding_dim,
|
||||
hidden_size=hidden_dim,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=rnn_dropout,
|
||||
)
|
||||
self.blank_id = blank_id
|
||||
self.sos_id = sos_id
|
||||
self.output_linear = nn.Linear(hidden_dim, hidden_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, U) with BOS prepended.
|
||||
states:
|
||||
A tuple of two tensors containing the states information of
|
||||
LSTM layers in this decoder.
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
|
||||
- rnn_output, a tensor of shape (N, U, C)
|
||||
- (h, c), containing the state information for LSTM layers.
|
||||
Both are of shape (num_layers, N, C)
|
||||
"""
|
||||
embeding_out = self.embedding(y)
|
||||
embeding_out = self.embedding_dropout(embeding_out)
|
||||
rnn_out, (h, c) = self.rnn(embeding_out, states)
|
||||
out = self.output_linear(rnn_out)
|
||||
|
||||
return out, (h, c)
|
53
egs/librispeech/ASR/transducer/encoder_interface.py
Normal file
53
egs/librispeech/ASR/transducer/encoder_interface.py
Normal file
@ -0,0 +1,53 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class EncoderInterface(nn.Module):
|
||||
def __init__(self, num_features: int, output_dim: int):
|
||||
"""
|
||||
Args:
|
||||
num_features:
|
||||
The dimension of the input features.
|
||||
output_dim:
|
||||
Output dimension of the model.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A tensor of shape (batch_size, input_seq_len, num_features)
|
||||
containing the input features.
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames
|
||||
in `x` before padding.
|
||||
Returns:
|
||||
Return a tuple containing two tensors:
|
||||
- encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
|
||||
containing unnormalized probabilities, i.e., the output of a
|
||||
linear layer.
|
||||
- encoder_out_lens, a tensor of shape (batch_size,) containing
|
||||
the number of frames in `encoder_out` before padding.
|
||||
"""
|
||||
raise NotImplementedError("Please implement it in a subclass")
|
55
egs/librispeech/ASR/transducer/joiner.py
Normal file
55
egs/librispeech/ASR/transducer/joiner.py
Normal file
@ -0,0 +1,55 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Joiner(nn.Module):
|
||||
def __init__(self, input_dim: int, output_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.output_linear = nn.Linear(input_dim, output_dim)
|
||||
|
||||
def forward(
|
||||
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, C).
|
||||
decoder_out:
|
||||
Output from the decoder. Its shape is (N, U, C).
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, U, C).
|
||||
"""
|
||||
assert encoder_out.ndim == decoder_out.ndim == 3
|
||||
assert encoder_out.size(0) == decoder_out.size(0)
|
||||
assert encoder_out.size(2) == decoder_out.size(2)
|
||||
|
||||
encoder_out = encoder_out.unsqueeze(2)
|
||||
# Now encoder_out is (N, T, 1, C)
|
||||
|
||||
decoder_out = decoder_out.unsqueeze(1)
|
||||
# Now decoder_out is (N, 1, U, C)
|
||||
|
||||
logit = encoder_out + decoder_out
|
||||
logit = F.relu(logit)
|
||||
|
||||
output = self.output_linear(logit)
|
||||
|
||||
return output
|
127
egs/librispeech/ASR/transducer/model.py
Normal file
127
egs/librispeech/ASR/transducer/model.py
Normal file
@ -0,0 +1,127 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Note we use `rnnt_loss` from torchaudio, which exists only in
|
||||
torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
|
||||
"""
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchaudio
|
||||
import torchaudio.functional
|
||||
from transducer.encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_sos
|
||||
|
||||
assert hasattr(torchaudio.functional, "rnnt_loss"), (
|
||||
f"Current torchaudio version: {torchaudio.__version__}\n"
|
||||
"Please install a version >= 0.10.0"
|
||||
)
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
||||
"Sequence Transduction with Recurrent Neural Networks"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: EncoderInterface,
|
||||
decoder: nn.Module,
|
||||
joiner: nn.Module,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
It is the transcription network in the paper. Its accepts
|
||||
two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, C) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, C). It should contain
|
||||
two attributes: `blank_id` and `sos_id`.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
|
||||
output shape is (N, T, U, C). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface)
|
||||
assert hasattr(decoder, "blank_id")
|
||||
assert hasattr(decoder, "sos_id")
|
||||
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||
|
||||
encoder_out, x_lens = self.encoder(x, x_lens)
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_id = self.decoder.sos_id
|
||||
sos_y = add_sos(y, sos_id=sos_id)
|
||||
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
decoder_out, _ = self.decoder(sos_y_padded)
|
||||
|
||||
logits = self.joiner(encoder_out, decoder_out)
|
||||
|
||||
# rnnt_loss requires 0 padded targets
|
||||
# Note: y does not start with SOS
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
loss = torchaudio.functional.rnnt_loss(
|
||||
logits=logits,
|
||||
targets=y_padded,
|
||||
logit_lengths=x_lens,
|
||||
target_lengths=y_lens,
|
||||
blank=blank_id,
|
||||
reduction="mean",
|
||||
)
|
||||
|
||||
return loss
|
67
egs/librispeech/ASR/transducer/test_decoder.py
Executable file
67
egs/librispeech/ASR/transducer/test_decoder.py
Executable file
@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/yesno/ASR
|
||||
python ./transducer/test_decoder.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from transducer.decoder import Decoder
|
||||
|
||||
|
||||
def test_decoder():
|
||||
vocab_size = 3
|
||||
blank_id = 0
|
||||
sos_id = 2
|
||||
embedding_dim = 128
|
||||
num_layers = 2
|
||||
hidden_dim = 6
|
||||
N = 3
|
||||
U = 5
|
||||
|
||||
decoder = Decoder(
|
||||
vocab_size=vocab_size,
|
||||
embedding_dim=embedding_dim,
|
||||
blank_id=blank_id,
|
||||
sos_id=sos_id,
|
||||
num_layers=num_layers,
|
||||
hidden_dim=hidden_dim,
|
||||
embedding_dropout=0.0,
|
||||
rnn_dropout=0.0,
|
||||
)
|
||||
x = torch.randint(1, vocab_size, (N, U))
|
||||
rnn_out, (h, c) = decoder(x)
|
||||
|
||||
assert rnn_out.shape == (N, U, hidden_dim)
|
||||
assert h.shape == (num_layers, N, hidden_dim)
|
||||
assert c.shape == (num_layers, N, hidden_dim)
|
||||
|
||||
rnn_out, (h, c) = decoder(x, (h, c))
|
||||
assert rnn_out.shape == (N, U, hidden_dim)
|
||||
assert h.shape == (num_layers, N, hidden_dim)
|
||||
assert c.shape == (num_layers, N, hidden_dim)
|
||||
|
||||
|
||||
def main():
|
||||
test_decoder()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
50
egs/librispeech/ASR/transducer/test_joiner.py
Executable file
50
egs/librispeech/ASR/transducer/test_joiner.py
Executable file
@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/yesno/ASR
|
||||
python ./transducer/test_joiner.py
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
from transducer.joiner import Joiner
|
||||
|
||||
|
||||
def test_joiner():
|
||||
N = 2
|
||||
T = 3
|
||||
C = 4
|
||||
U = 5
|
||||
|
||||
joiner = Joiner(C, 10)
|
||||
|
||||
encoder_out = torch.rand(N, T, C)
|
||||
decoder_out = torch.rand(N, U, C)
|
||||
|
||||
joint = joiner(encoder_out, decoder_out)
|
||||
assert joint.shape == (N, T, U, 10)
|
||||
|
||||
|
||||
def main():
|
||||
test_joiner()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user