diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/model.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/model.py
index 1e716e2ab..e816d5233 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/model.py
@@ -44,7 +44,8 @@ class Transducer(nn.Module):
             It is the transcription network in the paper. Its accepts
             two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
             It returns two tensors: `logits` of shape (N, T, C) and
-            `logit_lens` of shape (N,).
+            `logit_lens` of shape (N,). It should have an attribute:
+            output_dim.
           decoder:
             It is the prediction network in the paper. Its input shape
             is (N, U) and its output shape is (N, U, C). It should contain
@@ -70,9 +71,48 @@ class Transducer(nn.Module):
         self.decoder = decoder
         self.joiner = joiner
 
+        vocab_size = self.joiner.output_dim
+        joiner_dim = self.joiner.input_dim
+
+        # Note: self.joiner.output_dim is equal to vocab_size.
+        # This layer is to transform the decoder output for computing
+        # simple loss
+        self.simple_decoder_linear = nn.Linear(
+            self.decoder.embedding_dim, vocab_size
+        )
+
+        # This layer is to transform the encoder output for computing
+        # simple loss
+        self.simple_encoder_linear = nn.Linear(
+            self.encoder.output_dim, vocab_size
+        )
+
+        # Transform the output of decoder so that it can be added
+        # with the output of encoder in the joiner.
+        self.decoder_linear = nn.Linear(vocab_size, joiner_dim)
+
+        # Transform the output of encoder so that it can be added
+        # with the output of decoder in the joiner
+        self.encoder_linear = nn.Linear(vocab_size, joiner_dim)
+
         self.decoder_giga = decoder_giga
         self.joiner_giga = joiner_giga
 
+        if decoder_giga is not None:
+            self.simple_decoder_giga_linear = nn.Linear(
+                self.decoder.embedding_dim, vocab_size
+            )
+            self.simple_encoder_giga_linear = nn.Linear(
+                self.encoder.output_dim, vocab_size
+            )
+            self.decoder_giga_linear = nn.Linear(vocab_size, joiner_dim)
+            self.encoder_giga_linear = nn.Linear(vocab_size, joiner_dim)
+        else:
+            self.simple_decoder_giga_linear = None
+            self.simple_encoder_giga_linear = None
+            self.decoder_giga_linear = None
+            self.encoder_giga_linear = None
+
     def forward(
         self,
         x: torch.Tensor,
@@ -136,9 +176,17 @@ class Transducer(nn.Module):
         if libri:
             decoder = self.decoder
             joiner = self.joiner
+            simple_decoder_linear = self.simple_decoder_linear
+            simple_encoder_linear = self.simple_encoder_linear
+            decoder_linear = self.decoder_linear
+            encoder_linear = self.encoder_linear
         else:
             decoder = self.decoder_giga
             joiner = self.joiner_giga
+            simple_decoder_linear = self.simple_decoder_giga_linear
+            simple_encoder_linear = self.simple_encoder_giga_linear
+            decoder_linear = self.decoder_giga_linear
+            encoder_linear = self.encoder_giga_linear
 
         # decoder_out: [B, S + 1, C]
         decoder_out = decoder(sos_y_padded)
@@ -154,9 +202,12 @@
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
 
+        simple_decoder_out = simple_decoder_linear(decoder_out)
+        simple_encoder_out = simple_encoder_linear(encoder_out)
+
         simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
-            lm=decoder_out,
-            am=encoder_out,
+            lm=simple_decoder_out,
+            am=simple_encoder_out,
             symbols=y_padded,
             termination_symbol=blank_id,
             lm_only_scale=lm_scale,
@@ -177,9 +228,12 @@
         # am_pruned : [B, T, prune_range, C]
         # lm_pruned : [B, T, prune_range, C]
         am_pruned, lm_pruned = k2.do_rnnt_pruning(
-            am=encoder_out, lm=decoder_out, ranges=ranges
+            am=simple_encoder_out, lm=simple_decoder_out, ranges=ranges
        )
 
+        am_pruned = encoder_linear(am_pruned)
+        lm_pruned = decoder_linear(lm_pruned)
+
         # logits : [B, T, prune_range, C]
         logits = joiner(am_pruned, lm_pruned)

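To make the dimension bookkeeping above easier to follow, here is a minimal, self-contained PyTorch sketch of the flow through the four new projections. The sizes are hypothetical toy values, and the k2 pruning call is replaced by random tensors of the shape documented in the diff; only the layer names come from the patch.

import torch
import torch.nn as nn

# Hypothetical toy sizes, for illustration only.
B, T, S = 2, 50, 7        # batch, frames, max symbols per utterance
encoder_dim = 512         # encoder.output_dim (attention_dim)
embedding_dim = 512       # decoder.embedding_dim
vocab_size = 500          # joiner.output_dim
joiner_dim = 1024         # joiner.input_dim
prune_range = 5

encoder_out = torch.rand(B, T, encoder_dim)        # (N, T, C)
decoder_out = torch.rand(B, S + 1, embedding_dim)  # (N, S + 1, C)

# The four projections added in model.py.
simple_decoder_linear = nn.Linear(embedding_dim, vocab_size)
simple_encoder_linear = nn.Linear(encoder_dim, vocab_size)
decoder_linear = nn.Linear(vocab_size, joiner_dim)
encoder_linear = nn.Linear(vocab_size, joiner_dim)

# 1. Vocab-sized outputs: inputs to k2.rnnt_loss_smoothed and to pruning.
simple_decoder_out = simple_decoder_linear(decoder_out)  # (B, S + 1, vocab_size)
simple_encoder_out = simple_encoder_linear(encoder_out)  # (B, T, vocab_size)

# 2. k2.do_rnnt_pruning yields (B, T, prune_range, C) tensors; random
#    stand-ins of the same shape are used here instead.
am_pruned = torch.rand(B, T, prune_range, vocab_size)
lm_pruned = torch.rand(B, T, prune_range, vocab_size)

# 3. Project to joiner_dim so the joiner can add the two streams.
am_pruned = encoder_linear(am_pruned)  # (B, T, prune_range, joiner_dim)
lm_pruned = decoder_linear(lm_pruned)  # (B, T, prune_range, joiner_dim)
print(simple_encoder_out.shape, am_pruned.shape, lm_pruned.shape)
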
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/test_model.py
new file mode 100755
index 000000000..0ffed60db
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/test_model.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./pruned_transducer_stateless_multi_datasets/test_model.py
+"""
+
+
+import k2
+import torch
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from model import Transducer
+
+
+def test_model():
+    # encoder params
+    input_dim = 10
+    attention_dim = 512
+
+    # decoder params
+    vocab_size = 3
+    embedding_dim = 512
+    blank_id = 0
+    context_size = 2
+
+    joiner_dim = 1024
+
+    encoder = Conformer(
+        num_features=input_dim,
+        subsampling_factor=4,
+        d_model=attention_dim,
+        nhead=8,
+        dim_feedforward=2048,
+        num_encoder_layers=12,
+    )
+
+    decoder = Decoder(
+        vocab_size=vocab_size,
+        embedding_dim=embedding_dim,
+        blank_id=blank_id,
+        context_size=context_size,
+    )
+
+    joiner = Joiner(joiner_dim, vocab_size)
+    transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
+
+    y = k2.RaggedTensor([[1, 2, 1], [1, 1, 1, 2, 1]])
+    N = y.dim0
+    T = 50
+
+    x = torch.rand(N, T, input_dim)
+    x_lens = torch.randint(low=30, high=T, size=(N,), dtype=torch.int32)
+    x_lens[0] = T
+
+    loss = transducer(x, x_lens, y)
+    print(loss)
+
+
+def main():
+    test_model()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py
index 425879321..c03291113 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/train.py
@@ -290,6 +290,8 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "subsampling_factor": 4,
             "attention_dim": 512,
+            "decoder_embedding_dim": 512,
+            "joiner_dim": 1024,  # input dim of the joiner
             "nhead": 8,
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
@@ -320,7 +322,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
 def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
-        embedding_dim=params.attention_dim,
+        embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
         context_size=params.context_size,
     )
@@ -329,7 +331,7 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        input_dim=params.attention_dim,
+        input_dim=params.joiner_dim,
         output_dim=params.vocab_size,
     )
     return joiner
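With the new decoder_embedding_dim and joiner_dim entries, the encoder, decoder and joiner are sized independently; the projections added in model.py bridge any mismatch. A rough usage sketch, assuming it is run from the recipe directory like test_model.py above; vocab_size and blank_id are normally derived from the BPE model and are hard-coded here only for illustration.

from model import Transducer
from train import (
    get_decoder_model,
    get_encoder_model,
    get_joiner_model,
    get_params,
)

params = get_params()
params.vocab_size = 500  # hypothetical; normally sp.get_piece_size()
params.blank_id = 0      # hypothetical; normally sp.piece_to_id("<blk>")

encoder = get_encoder_model(params)  # output_dim == attention_dim == 512
decoder = get_decoder_model(params)  # embedding_dim == decoder_embedding_dim == 512
joiner = get_joiner_model(params)    # input_dim == joiner_dim == 1024

# attention_dim, decoder_embedding_dim and joiner_dim no longer have to agree.
model = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
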
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/transformer.py b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/transformer.py
index bce92f8f6..48f529f59 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/transformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/transformer.py
@@ -99,6 +99,7 @@ class Transformer(EncoderInterface):
             num_layers=num_encoder_layers,
             norm=encoder_norm,
         )
+        self.output_dim = d_model
 
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor
diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py
index b82fed37b..923cee106 100644
--- a/egs/librispeech/ASR/transducer_stateless/decoder.py
+++ b/egs/librispeech/ASR/transducer_stateless/decoder.py
@@ -58,6 +58,7 @@ class Decoder(nn.Module):
             padding_idx=blank_id,
         )
         self.blank_id = blank_id
+        self.embedding_dim = embedding_dim
 
         assert context_size >= 1, context_size
         self.context_size = context_size
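The last two hunks only add attributes: Transducer.__init__ now reads encoder.output_dim and decoder.embedding_dim to size its projections, so any encoder or decoder used with it must expose them. A toy encoder satisfying the contract spelled out in the model.py docstring might look like the following (hypothetical, for illustration only; the recipes use Conformer or Transformer).

from typing import Tuple

import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Bare-minimum encoder for the interface model.py expects:
    forward(x, x_lens) -> (out, out_lens), plus an `output_dim`
    attribute (the same attribute transformer.py gains above)."""

    def __init__(self, num_features: int, output_dim: int) -> None:
        super().__init__()
        self.output_dim = output_dim  # read by Transducer.__init__
        self.proj = nn.Linear(num_features, output_dim)

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # (N, T, num_features) -> (N, T, output_dim); no subsampling,
        # so the lengths pass through unchanged.
        return self.proj(x), x_lens


encoder = ToyEncoder(num_features=10, output_dim=512)
out, out_lens = encoder(torch.rand(2, 50, 10), torch.tensor([50, 30]))
print(out.shape, out_lens)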