Add nn.Linear to transform the output of encoder and decoder.

Fangjun Kuang 2022-03-11 15:41:03 +08:00
parent ec78b7ef72
commit 963ac73c27
5 changed files with 147 additions and 6 deletions


@@ -44,7 +44,8 @@ class Transducer(nn.Module):
             It is the transcription network in the paper. Its accepts
             two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
             It returns two tensors: `logits` of shape (N, T, C) and
-            `logit_lens` of shape (N,).
+            `logit_lens` of shape (N,). It should have an attribute:
+            output_dim.
           decoder:
             It is the prediction network in the paper. Its input shape
             is (N, U) and its output shape is (N, U, C). It should contain
@@ -70,9 +71,48 @@ class Transducer(nn.Module):
         self.decoder = decoder
         self.joiner = joiner

+        # Note: self.joiner.output_dim is equal to vocab_size.
+        vocab_size = self.joiner.output_dim
+        joiner_dim = self.joiner.input_dim
+
+        # This layer transforms the decoder output for computing
+        # the simple loss.
+        self.simple_decoder_linear = nn.Linear(
+            self.decoder.embedding_dim, vocab_size
+        )
+        # This layer transforms the encoder output for computing
+        # the simple loss.
+        self.simple_encoder_linear = nn.Linear(
+            self.encoder.output_dim, vocab_size
+        )
+
+        # Transform the output of the decoder so that it can be added
+        # to the output of the encoder in the joiner.
+        self.decoder_linear = nn.Linear(vocab_size, joiner_dim)
+
+        # Transform the output of the encoder so that it can be added
+        # to the output of the decoder in the joiner.
+        self.encoder_linear = nn.Linear(vocab_size, joiner_dim)
+
         self.decoder_giga = decoder_giga
         self.joiner_giga = joiner_giga

+        if decoder_giga is not None:
+            self.simple_decoder_giga_linear = nn.Linear(
+                self.decoder.embedding_dim, vocab_size
+            )
+            self.simple_encoder_giga_linear = nn.Linear(
+                self.encoder.output_dim, vocab_size
+            )
+            self.decoder_giga_linear = nn.Linear(vocab_size, joiner_dim)
+            self.encoder_giga_linear = nn.Linear(vocab_size, joiner_dim)
+        else:
+            self.simple_decoder_giga_linear = None
+            self.simple_encoder_giga_linear = None
+            self.decoder_giga_linear = None
+            self.encoder_giga_linear = None
+
     def forward(
         self,
         x: torch.Tensor,
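For orientation, here is a shape-only sketch of what the four new projection layers do. It is not taken from the recipe: the dimensions (vocab_size, N, T, U, s_range) are invented, and the pruned tensors are faked with random data instead of calling k2.

import torch
import torch.nn as nn

# Illustrative sizes only; embedding_dim/joiner_dim happen to match get_params()
# further below, everything else is invented for this sketch.
vocab_size, joiner_dim = 500, 1024
embedding_dim, encoder_dim = 512, 512
N, T, U, s_range = 2, 100, 10, 5

simple_decoder_linear = nn.Linear(embedding_dim, vocab_size)
simple_encoder_linear = nn.Linear(encoder_dim, vocab_size)
decoder_linear = nn.Linear(vocab_size, joiner_dim)
encoder_linear = nn.Linear(vocab_size, joiner_dim)

decoder_out = torch.rand(N, U + 1, embedding_dim)  # prediction network output
encoder_out = torch.rand(N, T, encoder_dim)        # transcription network output

# Projections feeding the simple (smoothed) RNN-T loss.
simple_decoder_out = simple_decoder_linear(decoder_out)  # (N, U + 1, vocab_size)
simple_encoder_out = simple_encoder_linear(encoder_out)  # (N, T, vocab_size)

# k2.do_rnnt_pruning would turn those into (N, T, s_range, vocab_size) slices;
# fake that here to show the second pair of projections.
am_pruned = torch.rand(N, T, s_range, vocab_size)
lm_pruned = torch.rand(N, T, s_range, vocab_size)
am_for_joiner = encoder_linear(am_pruned)  # (N, T, s_range, joiner_dim)
lm_for_joiner = decoder_linear(lm_pruned)  # (N, T, s_range, joiner_dim)
print(am_for_joiner.shape, lm_for_joiner.shape)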
@@ -136,9 +176,17 @@ class Transducer(nn.Module):
         if libri:
             decoder = self.decoder
             joiner = self.joiner
+            simple_decoder_linear = self.simple_decoder_linear
+            simple_encoder_linear = self.simple_encoder_linear
+            decoder_linear = self.decoder_linear
+            encoder_linear = self.encoder_linear
         else:
             decoder = self.decoder_giga
             joiner = self.joiner_giga
+            simple_decoder_linear = self.simple_decoder_giga_linear
+            simple_encoder_linear = self.simple_encoder_giga_linear
+            decoder_linear = self.decoder_giga_linear
+            encoder_linear = self.encoder_giga_linear

         # decoder_out: [B, S + 1, C]
         decoder_out = decoder(sos_y_padded)
@@ -154,9 +202,12 @@
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens

+        simple_decoder_out = simple_decoder_linear(decoder_out)
+        simple_encoder_out = simple_encoder_linear(encoder_out)
+
         simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
-            lm=decoder_out,
-            am=encoder_out,
+            lm=simple_decoder_out,
+            am=simple_encoder_out,
             symbols=y_padded,
             termination_symbol=blank_id,
             lm_only_scale=lm_scale,
@@ -177,9 +228,12 @@
         # am_pruned : [B, T, prune_range, C]
         # lm_pruned : [B, T, prune_range, C]
         am_pruned, lm_pruned = k2.do_rnnt_pruning(
-            am=encoder_out, lm=decoder_out, ranges=ranges
+            am=simple_encoder_out, lm=simple_decoder_out, ranges=ranges
         )

+        am_pruned = encoder_linear(am_pruned)
+        lm_pruned = decoder_linear(lm_pruned)
+
         # logits : [B, T, prune_range, C]
         logits = joiner(am_pruned, lm_pruned)
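To see where these projections sit in the loss computation, here is a hedged end-to-end sketch. The k2.rnnt_loss_smoothed and k2.do_rnnt_pruning calls mirror the ones in the hunks above; k2.get_rnnt_prune_ranges and the lm_only_scale value are assumptions (the part of forward() that builds `ranges` is not shown in this diff), and exact keyword arguments may differ across k2 versions. All sizes are invented.

import k2
import torch
import torch.nn as nn

N, T, S = 2, 20, 5                 # batch, frames, max symbol length
vocab_size, joiner_dim = 100, 512  # invented for this sketch
blank_id, prune_range = 0, 5

# Stand-ins for the outputs of simple_decoder_linear / simple_encoder_linear.
simple_decoder_out = torch.rand(N, S + 1, vocab_size, requires_grad=True)
simple_encoder_out = torch.rand(N, T, vocab_size, requires_grad=True)

y_padded = torch.randint(1, vocab_size, (N, S), dtype=torch.int64)
boundary = torch.zeros((N, 4), dtype=torch.int64)
boundary[:, 2] = S  # y_lens
boundary[:, 3] = T  # x_lens

simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
    lm=simple_decoder_out,
    am=simple_encoder_out,
    symbols=y_padded,
    termination_symbol=blank_id,
    lm_only_scale=0.25,  # assumed value for lm_scale
    boundary=boundary,
    return_grad=True,
)

# Assumed companion call: derive the pruning bounds from the simple-loss grads.
ranges = k2.get_rnnt_prune_ranges(
    px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=prune_range
)

# Pruning runs on the cheap vocab_size-dim tensors ...
am_pruned, lm_pruned = k2.do_rnnt_pruning(
    am=simple_encoder_out, lm=simple_decoder_out, ranges=ranges
)

# ... and only the pruned (N, T, prune_range, *) slices are lifted to joiner_dim.
encoder_linear = nn.Linear(vocab_size, joiner_dim)
decoder_linear = nn.Linear(vocab_size, joiner_dim)
am_for_joiner = encoder_linear(am_pruned)  # (N, T, prune_range, joiner_dim)
lm_for_joiner = decoder_linear(lm_pruned)  # (N, T, prune_range, joiner_dim)
print(simple_loss, am_for_joiner.shape, lm_for_joiner.shape)

Note the ordering: the 4-D tensors that reach the joiner only ever exist at the pruned (N, T, prune_range, joiner_dim) size, while the full-length encoder and decoder outputs stay at vocab_size.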


@@ -0,0 +1,83 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
To run this file, do:

    cd icefall/egs/librispeech/ASR
    python ./pruned_transducer_stateless_multi_datasets/test_model.py
"""

import k2
import torch
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer


def test_model():
    # encoder params
    input_dim = 10
    attention_dim = 512

    # decoder params
    vocab_size = 3
    embedding_dim = 512
    blank_id = 0
    context_size = 2

    joiner_dim = 1024

    encoder = Conformer(
        num_features=input_dim,
        subsampling_factor=4,
        d_model=attention_dim,
        nhead=8,
        dim_feedforward=2048,
        num_encoder_layers=12,
    )

    decoder = Decoder(
        vocab_size=vocab_size,
        embedding_dim=embedding_dim,
        blank_id=blank_id,
        context_size=context_size,
    )

    joiner = Joiner(joiner_dim, vocab_size)

    transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)

    y = k2.RaggedTensor([[1, 2, 1], [1, 1, 1, 2, 1]])
    N = y.dim0
    T = 50

    x = torch.rand(N, T, input_dim)
    x_lens = torch.randint(low=30, high=T, size=(N,), dtype=torch.int32)
    x_lens[0] = T

    loss = transducer(x, x_lens, y)
    print(loss)


def main():
    test_model()


if __name__ == "__main__":
    main()


@@ -290,6 +290,8 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "subsampling_factor": 4,
             "attention_dim": 512,
+            "decoder_embedding_dim": 512,
+            "joiner_dim": 1024,  # input dim of the joiner
             "nhead": 8,
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
@@ -320,7 +322,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
 def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
-        embedding_dim=params.attention_dim,
+        embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
         context_size=params.context_size,
     )
@@ -329,7 +331,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        input_dim=params.attention_dim,
+        input_dim=params.joiner_dim,
         output_dim=params.vocab_size,
     )
     return joiner
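The two new entries in get_params() exist so that the sub-modules advertise the sizes that Transducer.__init__ reads back. Below is a small sketch of that wiring, assuming it is run from the recipe directory (like test_model.py above); vocab_size, blank_id, and context_size values are stand-ins, since in train.py they come from the BPE model and command-line options.

from types import SimpleNamespace

from decoder import Decoder
from joiner import Joiner

# Only the fields relevant to this commit; dims mirror get_params().
params = SimpleNamespace(
    decoder_embedding_dim=512,
    joiner_dim=1024,  # input dim of the joiner
    vocab_size=500,   # invented; set from the BPE model in train.py
    blank_id=0,
    context_size=2,
)

decoder = Decoder(
    vocab_size=params.vocab_size,
    embedding_dim=params.decoder_embedding_dim,
    blank_id=params.blank_id,
    context_size=params.context_size,
)
joiner = Joiner(input_dim=params.joiner_dim, output_dim=params.vocab_size)

# These are exactly the attributes Transducer.__init__ uses to size its
# new projection layers.
assert decoder.embedding_dim == params.decoder_embedding_dim
assert joiner.input_dim == params.joiner_dim
assert joiner.output_dim == params.vocab_size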


@@ -99,6 +99,7 @@ class Transformer(EncoderInterface):
             num_layers=num_encoder_layers,
             norm=encoder_norm,
         )
+        self.output_dim = d_model

     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor


@@ -58,6 +58,7 @@ class Decoder(nn.Module):
             padding_idx=blank_id,
         )
         self.blank_id = blank_id
+        self.embedding_dim = embedding_dim

         assert context_size >= 1, context_size
         self.context_size = context_size
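The last two hunks simply expose the sizes that model.py now reads from its sub-modules: the encoder advertises output_dim (as the updated docstring in the first hunk requires) and the decoder advertises embedding_dim. Below is a minimal sketch of that contract; TinyEncoder and TinyDecoder are invented stand-ins, not classes from the recipe.

import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Toy transcription network satisfying the interface Transducer expects."""

    def __init__(self, num_features: int, output_dim: int):
        super().__init__()
        self.output_dim = output_dim  # read when building simple_encoder_linear
        self.proj = nn.Linear(num_features, output_dim)

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor):
        # (N, T, num_features) -> (N, T, output_dim); lengths unchanged here.
        return self.proj(x), x_lens


class TinyDecoder(nn.Module):
    """Toy prediction network satisfying the interface Transducer expects."""

    def __init__(self, vocab_size: int, embedding_dim: int, blank_id: int = 0):
        super().__init__()
        self.embedding_dim = embedding_dim  # read when building simple_decoder_linear
        self.blank_id = blank_id
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=blank_id)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # (N, U) -> (N, U, embedding_dim)
        return self.embedding(y)


encoder, decoder = TinyEncoder(80, 512), TinyDecoder(500, 512)
x, x_lens = torch.rand(2, 100, 80), torch.tensor([100, 90])
logits, logit_lens = encoder(x, x_lens)
print(logits.shape, logit_lens, decoder(torch.tensor([[1, 2, 3]])).shape)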