mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 14:44:18 +00:00
Add nn.Linear to transform the output of encoder and decoder.
This commit is contained in:
parent
ec78b7ef72
commit
963ac73c27
@ -44,7 +44,8 @@ class Transducer(nn.Module):
|
||||
It is the transcription network in the paper. Its accepts
|
||||
two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, C) and
|
||||
`logit_lens` of shape (N,).
|
||||
`logit_lens` of shape (N,). It should have an attribute:
|
||||
output_dim.
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, C). It should contain
|
||||
@ -70,9 +71,48 @@ class Transducer(nn.Module):
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
vocab_size = self.joiner.output_dim
|
||||
joiner_dim = self.joiner.input_dim
|
||||
|
||||
# Note: self.joiner.output_dim is equal to vocab_size.
|
||||
# This layer is to transform the decoder output for computing
|
||||
# simple loss
|
||||
self.simple_decoder_linear = nn.Linear(
|
||||
self.decoder.embedding_dim, vocab_size
|
||||
)
|
||||
|
||||
# This layer is to transform the encoder output for computing
|
||||
# simple loss
|
||||
self.simple_encoder_linear = nn.Linear(
|
||||
self.encoder.output_dim, vocab_size
|
||||
)
|
||||
|
||||
# Transform the output of decoder so that it can be added
|
||||
# with the output of encoder in the joiner.
|
||||
self.decoder_linear = nn.Linear(vocab_size, joiner_dim)
|
||||
|
||||
# Transform the output of encoder so that it can be added
|
||||
# with the output of decoder in the joiner
|
||||
self.encoder_linear = nn.Linear(vocab_size, joiner_dim)
|
||||
|
||||
self.decoder_giga = decoder_giga
|
||||
self.joiner_giga = joiner_giga
|
||||
|
||||
if decoder_giga is not None:
|
||||
self.simple_decoder_giga_linear = nn.Linear(
|
||||
self.decoder.embedding_dim, vocab_size
|
||||
)
|
||||
self.simple_encoder_giga_linear = nn.Linear(
|
||||
self.encoder.output_dim, vocab_size
|
||||
)
|
||||
self.decoder_giga_linear = nn.Linear(vocab_size, joiner_dim)
|
||||
self.encoder_giga_linear = nn.Linear(vocab_size, joiner_dim)
|
||||
else:
|
||||
self.simple_decoder_giga_linear = None
|
||||
self.simple_encoder_giga_linear = None
|
||||
self.decoder_giga_linear = None
|
||||
self.encoder_giga_linear = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@ -136,9 +176,17 @@ class Transducer(nn.Module):
|
||||
if libri:
|
||||
decoder = self.decoder
|
||||
joiner = self.joiner
|
||||
simple_decoder_linear = self.simple_decoder_linear
|
||||
simple_encoder_linear = self.simple_encoder_linear
|
||||
decoder_linear = self.decoder_linear
|
||||
encoder_linear = self.encoder_linear
|
||||
else:
|
||||
decoder = self.decoder_giga
|
||||
joiner = self.joiner_giga
|
||||
simple_decoder_linear = self.simple_decoder_giga_linear
|
||||
simple_encoder_linear = self.simple_encoder_giga_linear
|
||||
decoder_linear = self.decoder_giga_linear
|
||||
encoder_linear = self.encoder_giga_linear
|
||||
|
||||
# decoder_out: [B, S + 1, C]
|
||||
decoder_out = decoder(sos_y_padded)
|
||||
@ -154,9 +202,12 @@ class Transducer(nn.Module):
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
||||
simple_decoder_out = simple_decoder_linear(decoder_out)
|
||||
simple_encoder_out = simple_encoder_linear(encoder_out)
|
||||
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=decoder_out,
|
||||
am=encoder_out,
|
||||
lm=simple_decoder_out,
|
||||
am=simple_encoder_out,
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=lm_scale,
|
||||
@ -177,9 +228,12 @@ class Transducer(nn.Module):
|
||||
# am_pruned : [B, T, prune_range, C]
|
||||
# lm_pruned : [B, T, prune_range, C]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=encoder_out, lm=decoder_out, ranges=ranges
|
||||
am=simple_encoder_out, lm=simple_decoder_out, ranges=ranges
|
||||
)
|
||||
|
||||
am_pruned = encoder_linear(am_pruned)
|
||||
lm_pruned = decoder_linear(lm_pruned)
|
||||
|
||||
# logits : [B, T, prune_range, C]
|
||||
logits = joiner(am_pruned, lm_pruned)
|
||||
|
||||
|
83
egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/test_model.py
Executable file
83
egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/test_model.py
Executable file
@ -0,0 +1,83 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./pruned_transducer_stateless_multi_datasets/test_model.py
|
||||
"""
|
||||
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from model import Transducer
|
||||
|
||||
|
||||
def test_model():
|
||||
# encoder params
|
||||
input_dim = 10
|
||||
attention_dim = 512
|
||||
|
||||
# decoder params
|
||||
vocab_size = 3
|
||||
embedding_dim = 512
|
||||
blank_id = 0
|
||||
context_size = 2
|
||||
|
||||
joiner_dim = 1024
|
||||
|
||||
encoder = Conformer(
|
||||
num_features=input_dim,
|
||||
subsampling_factor=4,
|
||||
d_model=attention_dim,
|
||||
nhead=8,
|
||||
dim_feedforward=2048,
|
||||
num_encoder_layers=12,
|
||||
)
|
||||
|
||||
decoder = Decoder(
|
||||
vocab_size=vocab_size,
|
||||
embedding_dim=embedding_dim,
|
||||
blank_id=blank_id,
|
||||
context_size=context_size,
|
||||
)
|
||||
|
||||
joiner = Joiner(joiner_dim, vocab_size)
|
||||
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
|
||||
|
||||
y = k2.RaggedTensor([[1, 2, 1], [1, 1, 1, 2, 1]])
|
||||
N = y.dim0
|
||||
T = 50
|
||||
|
||||
x = torch.rand(N, T, input_dim)
|
||||
x_lens = torch.randint(low=30, high=T, size=(N,), dtype=torch.int32)
|
||||
x_lens[0] = T
|
||||
|
||||
loss = transducer(x, x_lens, y)
|
||||
print(loss)
|
||||
|
||||
|
||||
def main():
|
||||
test_model()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -290,6 +290,8 @@ def get_params() -> AttributeDict:
|
||||
"feature_dim": 80,
|
||||
"subsampling_factor": 4,
|
||||
"attention_dim": 512,
|
||||
"decoder_embedding_dim": 512,
|
||||
"joiner_dim": 1024, # input dim of the joiner
|
||||
"nhead": 8,
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 12,
|
||||
@ -320,7 +322,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||
decoder = Decoder(
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.attention_dim,
|
||||
embedding_dim=params.decoder_embedding_dim,
|
||||
blank_id=params.blank_id,
|
||||
context_size=params.context_size,
|
||||
)
|
||||
@ -329,7 +331,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||
|
||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
joiner = Joiner(
|
||||
input_dim=params.attention_dim,
|
||||
input_dim=params.joiner_dim,
|
||||
output_dim=params.vocab_size,
|
||||
)
|
||||
return joiner
|
||||
|
@ -99,6 +99,7 @@ class Transformer(EncoderInterface):
|
||||
num_layers=num_encoder_layers,
|
||||
norm=encoder_norm,
|
||||
)
|
||||
self.output_dim = d_model
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||
|
@ -58,6 +58,7 @@ class Decoder(nn.Module):
|
||||
padding_idx=blank_id,
|
||||
)
|
||||
self.blank_id = blank_id
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
assert context_size >= 1, context_size
|
||||
self.context_size = context_size
|
||||
|
Loading…
x
Reference in New Issue
Block a user