Add nn.Linear to transform the output of encoder and decoder.

This commit is contained in:
Fangjun Kuang 2022-03-11 15:41:03 +08:00
parent ec78b7ef72
commit 963ac73c27
5 changed files with 147 additions and 6 deletions

View File

@ -44,7 +44,8 @@ class Transducer(nn.Module):
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, C) and
`logit_lens` of shape (N,).
`logit_lens` of shape (N,). It should have an attribute:
output_dim.
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain
@ -70,9 +71,48 @@ class Transducer(nn.Module):
self.decoder = decoder
self.joiner = joiner
vocab_size = self.joiner.output_dim
joiner_dim = self.joiner.input_dim
# Note: self.joiner.output_dim is equal to vocab_size.
# This layer is to transform the decoder output for computing
# simple loss
self.simple_decoder_linear = nn.Linear(
self.decoder.embedding_dim, vocab_size
)
# This layer is to transform the encoder output for computing
# simple loss
self.simple_encoder_linear = nn.Linear(
self.encoder.output_dim, vocab_size
)
# Transform the output of decoder so that it can be added
# with the output of encoder in the joiner.
self.decoder_linear = nn.Linear(vocab_size, joiner_dim)
# Transform the output of encoder so that it can be added
# with the output of decoder in the joiner
self.encoder_linear = nn.Linear(vocab_size, joiner_dim)
self.decoder_giga = decoder_giga
self.joiner_giga = joiner_giga
if decoder_giga is not None:
self.simple_decoder_giga_linear = nn.Linear(
self.decoder.embedding_dim, vocab_size
)
self.simple_encoder_giga_linear = nn.Linear(
self.encoder.output_dim, vocab_size
)
self.decoder_giga_linear = nn.Linear(vocab_size, joiner_dim)
self.encoder_giga_linear = nn.Linear(vocab_size, joiner_dim)
else:
self.simple_decoder_giga_linear = None
self.simple_encoder_giga_linear = None
self.decoder_giga_linear = None
self.encoder_giga_linear = None
def forward(
self,
x: torch.Tensor,
@ -136,9 +176,17 @@ class Transducer(nn.Module):
if libri:
decoder = self.decoder
joiner = self.joiner
simple_decoder_linear = self.simple_decoder_linear
simple_encoder_linear = self.simple_encoder_linear
decoder_linear = self.decoder_linear
encoder_linear = self.encoder_linear
else:
decoder = self.decoder_giga
joiner = self.joiner_giga
simple_decoder_linear = self.simple_decoder_giga_linear
simple_encoder_linear = self.simple_encoder_giga_linear
decoder_linear = self.decoder_giga_linear
encoder_linear = self.encoder_giga_linear
# decoder_out: [B, S + 1, C]
decoder_out = decoder(sos_y_padded)
@ -154,9 +202,12 @@ class Transducer(nn.Module):
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
simple_decoder_out = simple_decoder_linear(decoder_out)
simple_encoder_out = simple_encoder_linear(encoder_out)
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=decoder_out,
am=encoder_out,
lm=simple_decoder_out,
am=simple_encoder_out,
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
@ -177,9 +228,12 @@ class Transducer(nn.Module):
# am_pruned : [B, T, prune_range, C]
# lm_pruned : [B, T, prune_range, C]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=encoder_out, lm=decoder_out, ranges=ranges
am=simple_encoder_out, lm=simple_decoder_out, ranges=ranges
)
am_pruned = encoder_linear(am_pruned)
lm_pruned = decoder_linear(lm_pruned)
# logits : [B, T, prune_range, C]
logits = joiner(am_pruned, lm_pruned)

View File

@ -0,0 +1,83 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless_multi_datasets/test_model.py
"""
import k2
import torch
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
def test_model():
# encoder params
input_dim = 10
attention_dim = 512
# decoder params
vocab_size = 3
embedding_dim = 512
blank_id = 0
context_size = 2
joiner_dim = 1024
encoder = Conformer(
num_features=input_dim,
subsampling_factor=4,
d_model=attention_dim,
nhead=8,
dim_feedforward=2048,
num_encoder_layers=12,
)
decoder = Decoder(
vocab_size=vocab_size,
embedding_dim=embedding_dim,
blank_id=blank_id,
context_size=context_size,
)
joiner = Joiner(joiner_dim, vocab_size)
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
y = k2.RaggedTensor([[1, 2, 1], [1, 1, 1, 2, 1]])
N = y.dim0
T = 50
x = torch.rand(N, T, input_dim)
x_lens = torch.randint(low=30, high=T, size=(N,), dtype=torch.int32)
x_lens[0] = T
loss = transducer(x, x_lens, y)
print(loss)
def main():
test_model()
if __name__ == "__main__":
main()

View File

@ -290,6 +290,8 @@ def get_params() -> AttributeDict:
"feature_dim": 80,
"subsampling_factor": 4,
"attention_dim": 512,
"decoder_embedding_dim": 512,
"joiner_dim": 1024, # input dim of the joiner
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
@ -320,7 +322,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.attention_dim,
embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
@ -329,7 +331,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.attention_dim,
input_dim=params.joiner_dim,
output_dim=params.vocab_size,
)
return joiner

View File

@ -99,6 +99,7 @@ class Transformer(EncoderInterface):
num_layers=num_encoder_layers,
norm=encoder_norm,
)
self.output_dim = d_model
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor

View File

@ -58,6 +58,7 @@ class Decoder(nn.Module):
padding_idx=blank_id,
)
self.blank_id = blank_id
self.embedding_dim = embedding_dim
assert context_size >= 1, context_size
self.context_size = context_size