Mirror of https://github.com/k2-fsa/icefall.git

commit 963ac73c27 (parent ec78b7ef72)

    Add nn.Linear to transform the output of encoder and decoder.
@@ -44,7 +44,8 @@ class Transducer(nn.Module):
             It is the transcription network in the paper. Its accepts
             two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
             It returns two tensors: `logits` of shape (N, T, C) and
-            `logit_lens` of shape (N,).
+            `logit_lens` of shape (N,). It should have an attribute:
+            output_dim.
           decoder:
             It is the prediction network in the paper. Its input shape
             is (N, U) and its output shape is (N, U, C). It should contain
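The docstring change above tightens the encoder contract: the encoder must now expose an output_dim attribute, which the new projection layers (next hunk) read to size themselves. A minimal conforming encoder might look like the following; ToyEncoder is a hypothetical illustration, not code from this commit.

import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """A stand-in satisfying the documented interface (illustration only)."""

    def __init__(self, num_features: int, output_dim: int):
        super().__init__()
        # Required by Transducer after this commit:
        self.output_dim = output_dim
        self.proj = nn.Linear(num_features, output_dim)

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor):
        # x: (N, T, C) -> logits: (N, T, output_dim); lengths pass through
        return self.proj(x), x_lens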
@@ -70,9 +71,48 @@ class Transducer(nn.Module):
         self.decoder = decoder
         self.joiner = joiner
 
+        vocab_size = self.joiner.output_dim
+        joiner_dim = self.joiner.input_dim
+
+        # Note: self.joiner.output_dim is equal to vocab_size.
+        # This layer is to transform the decoder output for computing
+        # simple loss
+        self.simple_decoder_linear = nn.Linear(
+            self.decoder.embedding_dim, vocab_size
+        )
+
+        # This layer is to transform the encoder output for computing
+        # simple loss
+        self.simple_encoder_linear = nn.Linear(
+            self.encoder.output_dim, vocab_size
+        )
+
+        # Transform the output of decoder so that it can be added
+        # with the output of encoder in the joiner.
+        self.decoder_linear = nn.Linear(vocab_size, joiner_dim)
+
+        # Transform the output of encoder so that it can be added
+        # with the output of decoder in the joiner
+        self.encoder_linear = nn.Linear(vocab_size, joiner_dim)
+
         self.decoder_giga = decoder_giga
         self.joiner_giga = joiner_giga
 
+        if decoder_giga is not None:
+            self.simple_decoder_giga_linear = nn.Linear(
+                self.decoder.embedding_dim, vocab_size
+            )
+            self.simple_encoder_giga_linear = nn.Linear(
+                self.encoder.output_dim, vocab_size
+            )
+            self.decoder_giga_linear = nn.Linear(vocab_size, joiner_dim)
+            self.encoder_giga_linear = nn.Linear(vocab_size, joiner_dim)
+        else:
+            self.simple_decoder_giga_linear = None
+            self.simple_encoder_giga_linear = None
+            self.decoder_giga_linear = None
+            self.encoder_giga_linear = None
+
     def forward(
         self,
         x: torch.Tensor,
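Why four layers? The two simple_* projections map each stream straight to vocab_size for the smoothed "simple" loss, while decoder_linear/encoder_linear lift the (vocab_size-dim) pruned tensors to joiner_dim so the joiner can add them elementwise. A shape trace as a sketch: the 512/1024 sizes match the defaults added to get_params below, and vocab_size=500 is a hypothetical value.

import torch
import torch.nn as nn

N, T, U = 2, 50, 10
encoder_dim, embed_dim = 512, 512   # encoder.output_dim, decoder.embedding_dim
vocab_size, joiner_dim = 500, 1024  # joiner.output_dim, joiner.input_dim

simple_encoder_linear = nn.Linear(encoder_dim, vocab_size)
simple_decoder_linear = nn.Linear(embed_dim, vocab_size)
encoder_linear = nn.Linear(vocab_size, joiner_dim)
decoder_linear = nn.Linear(vocab_size, joiner_dim)

encoder_out = torch.rand(N, T, encoder_dim)
decoder_out = torch.rand(N, U + 1, embed_dim)

# Simple-loss path: both streams go straight to vocab_size.
simple_encoder_out = simple_encoder_linear(encoder_out)  # (N, T, 500)
simple_decoder_out = simple_decoder_linear(decoder_out)  # (N, U + 1, 500)

# Pruned path: in the real forward these are applied to the pruned
# (N, T, prune_range, vocab_size) tensors; applying them to the unpruned
# tensors here just shows the dimension change vocab_size -> joiner_dim.
am = encoder_linear(simple_encoder_out)  # (N, T, 1024)
lm = decoder_linear(simple_decoder_out)  # (N, U + 1, 1024)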
@@ -136,9 +176,17 @@ class Transducer(nn.Module):
         if libri:
             decoder = self.decoder
             joiner = self.joiner
+            simple_decoder_linear = self.simple_decoder_linear
+            simple_encoder_linear = self.simple_encoder_linear
+            decoder_linear = self.decoder_linear
+            encoder_linear = self.encoder_linear
         else:
             decoder = self.decoder_giga
             joiner = self.joiner_giga
+            simple_decoder_linear = self.simple_decoder_giga_linear
+            simple_encoder_linear = self.simple_encoder_giga_linear
+            decoder_linear = self.decoder_giga_linear
+            encoder_linear = self.encoder_giga_linear
 
         # decoder_out: [B, S + 1, C]
         decoder_out = decoder(sos_y_padded)
@@ -154,9 +202,12 @@ class Transducer(nn.Module):
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
 
+        simple_decoder_out = simple_decoder_linear(decoder_out)
+        simple_encoder_out = simple_encoder_linear(encoder_out)
+
         simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
-            lm=decoder_out,
-            am=encoder_out,
+            lm=simple_decoder_out,
+            am=simple_encoder_out,
             symbols=y_padded,
             termination_symbol=blank_id,
             lm_only_scale=lm_scale,
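For reference, a standalone sketch of the smoothed-loss call on random tensors. The keyword names and shapes follow the hunk; the concrete scale values, the am_only_scale/boundary/return_grad arguments, and vocab_size=500 are assumptions (return_grad=True is what makes the call yield the (px_grad, py_grad) pair unpacked above).

import k2
import torch

N, T, U, vocab_size, blank_id = 2, 50, 10, 500, 0

# Vocab-sized streams, as produced by the simple_* projections.
simple_encoder_out = torch.rand(N, T, vocab_size, requires_grad=True)
simple_decoder_out = torch.rand(N, U + 1, vocab_size, requires_grad=True)
y_padded = torch.randint(1, vocab_size, (N, U))
boundary = torch.zeros(N, 4, dtype=torch.int64)
boundary[:, 2] = U  # y_lens
boundary[:, 3] = T  # x_lens

simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
    lm=simple_decoder_out,
    am=simple_encoder_out,
    symbols=y_padded,
    termination_symbol=blank_id,
    lm_only_scale=0.25,  # hypothetical value for lm_scale
    am_only_scale=0.0,   # assumed
    boundary=boundary,
    return_grad=True,
)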
@@ -177,9 +228,12 @@ class Transducer(nn.Module):
         # am_pruned : [B, T, prune_range, C]
         # lm_pruned : [B, T, prune_range, C]
         am_pruned, lm_pruned = k2.do_rnnt_pruning(
-            am=encoder_out, lm=decoder_out, ranges=ranges
+            am=simple_encoder_out, lm=simple_decoder_out, ranges=ranges
         )
 
+        am_pruned = encoder_linear(am_pruned)
+        lm_pruned = decoder_linear(lm_pruned)
+
         # logits : [B, T, prune_range, C]
         logits = joiner(am_pruned, lm_pruned)
 
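The occupation-probability gradients from the simple loss drive the pruning. Below is a self-contained sketch of the whole pruned path; k2.get_rnnt_prune_ranges is not shown in the hunks, so treating it as the source of ranges is an assumption about the surrounding code.

import k2
import torch
import torch.nn as nn

N, T, U, vocab_size, joiner_dim, blank_id = 2, 50, 10, 500, 1024, 0
prune_range = 5

simple_encoder_out = torch.rand(N, T, vocab_size, requires_grad=True)
simple_decoder_out = torch.rand(N, U + 1, vocab_size, requires_grad=True)
y_padded = torch.randint(1, vocab_size, (N, U))
boundary = torch.zeros(N, 4, dtype=torch.int64)
boundary[:, 2] = U
boundary[:, 3] = T

_, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
    lm=simple_decoder_out, am=simple_encoder_out, symbols=y_padded,
    termination_symbol=blank_id, lm_only_scale=0.25, am_only_scale=0.0,
    boundary=boundary, return_grad=True,
)

# Assumption: ranges is produced this way in the surrounding code.
ranges = k2.get_rnnt_prune_ranges(
    px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=prune_range
)  # (N, T, prune_range)

am_pruned, lm_pruned = k2.do_rnnt_pruning(
    am=simple_encoder_out, lm=simple_decoder_out, ranges=ranges
)  # both (N, T, prune_range, vocab_size)

# The new projections lift the pruned streams to the joiner dimension,
# after which joiner(am_pruned, lm_pruned) yields logits of shape
# (N, T, prune_range, vocab_size).
am_pruned = nn.Linear(vocab_size, joiner_dim)(am_pruned)
lm_pruned = nn.Linear(vocab_size, joiner_dim)(lm_pruned)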
new file (executable, 83 lines): egs/librispeech/ASR/pruned_transducer_stateless_multi_datasets/test_model.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./pruned_transducer_stateless_multi_datasets/test_model.py
+"""
+
+
+import k2
+import torch
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from model import Transducer
+
+
+def test_model():
+    # encoder params
+    input_dim = 10
+    attention_dim = 512
+
+    # decoder params
+    vocab_size = 3
+    embedding_dim = 512
+    blank_id = 0
+    context_size = 2
+
+    joiner_dim = 1024
+
+    encoder = Conformer(
+        num_features=input_dim,
+        subsampling_factor=4,
+        d_model=attention_dim,
+        nhead=8,
+        dim_feedforward=2048,
+        num_encoder_layers=12,
+    )
+
+    decoder = Decoder(
+        vocab_size=vocab_size,
+        embedding_dim=embedding_dim,
+        blank_id=blank_id,
+        context_size=context_size,
+    )
+
+    joiner = Joiner(joiner_dim, vocab_size)
+    transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
+
+    y = k2.RaggedTensor([[1, 2, 1], [1, 1, 1, 2, 1]])
+    N = y.dim0
+    T = 50
+
+    x = torch.rand(N, T, input_dim)
+    x_lens = torch.randint(low=30, high=T, size=(N,), dtype=torch.int32)
+    x_lens[0] = T
+
+    loss = transducer(x, x_lens, y)
+    print(loss)
+
+
+def main():
+    test_model()
+
+
+if __name__ == "__main__":
+    main()
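The script is a smoke test: it only checks that a forward pass runs and prints the loss. A few shape assertions one might append inside test_model() (hypothetical additions, not part of the commit):

# Hypothetical checks on the projection layers added by this commit:
assert transducer.simple_encoder_linear.out_features == vocab_size
assert transducer.simple_decoder_linear.out_features == vocab_size
assert transducer.encoder_linear.out_features == joiner_dim
assert transducer.decoder_linear.out_features == joiner_dim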
@@ -290,6 +290,8 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "subsampling_factor": 4,
             "attention_dim": 512,
+            "decoder_embedding_dim": 512,
+            "joiner_dim": 1024,  # input dim of the joiner
             "nhead": 8,
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
@@ -320,7 +322,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
 def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
-        embedding_dim=params.attention_dim,
+        embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
         context_size=params.context_size,
     )
@@ -329,7 +331,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
 
 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        input_dim=params.attention_dim,
+        input_dim=params.joiner_dim,
         output_dim=params.vocab_size,
     )
     return joiner
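The joiner's input_dim is now the independent joiner_dim rather than attention_dim, which works because the joiner adds the two (already joiner_dim-sized) streams before its final projection to the vocabulary. A sketch of a joiner with the Joiner(input_dim, output_dim) interface follows; the tanh-then-project body is typical of these recipes but is an assumption, not necessarily this recipe's exact code.

import torch
import torch.nn as nn


class ToyJoiner(nn.Module):
    """Sketch of a joiner exposing the attributes Transducer reads."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.input_dim = input_dim    # read by Transducer as joiner_dim
        self.output_dim = output_dim  # read by Transducer as vocab_size
        self.output_linear = nn.Linear(input_dim, output_dim)

    def forward(self, am: torch.Tensor, lm: torch.Tensor) -> torch.Tensor:
        # am, lm: (N, T, prune_range, input_dim); sharing input_dim is what
        # makes the elementwise addition possible.
        return self.output_linear(torch.tanh(am + lm))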
@@ -99,6 +99,7 @@ class Transformer(EncoderInterface):
             num_layers=num_encoder_layers,
             norm=encoder_norm,
         )
+        self.output_dim = d_model
 
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor
@@ -58,6 +58,7 @@ class Decoder(nn.Module):
             padding_idx=blank_id,
         )
         self.blank_id = blank_id
+        self.embedding_dim = embedding_dim
 
         assert context_size >= 1, context_size
         self.context_size = context_size
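Like self.output_dim on the Transformer encoder above, self.embedding_dim exists only so that Transducer.__init__ can size its projections without hard-coding dimensions; the consuming lines from the first hunks, repeated for reference:

self.simple_decoder_linear = nn.Linear(
    self.decoder.embedding_dim, vocab_size
)
self.simple_encoder_linear = nn.Linear(
    self.encoder.output_dim, vocab_size
)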