mirror of https://github.com/k2-fsa/icefall.git

Increase the size of the context in the RNN-T decoder.

commit 04977175a3
parent cb04c8a750
@@ -130,6 +130,8 @@ def get_params() -> AttributeDict:
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             "use_feat_batchnorm": True,
+            # parameters for decoder
+            "context_size": 2,  # tri-gram
             # decoder params
             "env_info": get_env_info(),
         }
@@ -158,6 +160,7 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
         blank_id=params.blank_id,
+        context_size=params.context_size,
     )
     return decoder
@@ -16,6 +16,7 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 
 class Decoder(nn.Module):
@@ -35,6 +36,7 @@ class Decoder(nn.Module):
         vocab_size: int,
         embedding_dim: int,
         blank_id: int,
+        context_size: int,
     ):
         """
         Args:
@@ -44,6 +46,9 @@ class Decoder(nn.Module):
             Dimension of the input embedding.
           blank_id:
             The ID of the blank symbol.
+          context_size:
+            Number of previous words to use to predict the next word.
+            1 means bigram; 2 means trigram. n means (n+1)-gram.
         """
         super().__init__()
         self.embedding = nn.Embedding(
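As a concrete reading of the docstring added above (a hedged illustration, not part of the commit): with "context_size": 2, the "tri-gram" of the params dicts, the decoder conditions each prediction on only the two most recent labels, exactly the conditioning window of a trigram language model.

    # Hypothetical illustration of the conditioning window; plain Python,
    # nothing here comes from the commit itself.
    labels = ["a", "b", "c", "d"]    # labels decoded so far
    context_size = 2                 # "tri-gram" in the params dicts
    window = labels[-context_size:]  # -> ["c", "d"]
    # The next label is predicted from `window` alone, which is why the
    # decoder is stateless: no RNN hidden state is carried between steps.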
@@ -53,6 +58,18 @@ class Decoder(nn.Module):
         )
         self.blank_id = blank_id
 
+        assert context_size >= 1, context_size
+        self.context_size = context_size
+        if context_size > 1:
+            self.conv = nn.Conv1d(
+                in_channels=embedding_dim,
+                out_channels=embedding_dim,
+                kernel_size=context_size,
+                padding=0,
+                groups=embedding_dim,
+                bias=False,
+            )
+
     def forward(self, y: torch.Tensor) -> torch.Tensor:
         """
         Args:
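The Conv1d set up above is depthwise (groups=embedding_dim with equal in/out channels), so each embedding channel becomes a learned weighted sum of that same channel over the last context_size positions; channels never mix. A minimal runnable sketch of just that behavior, not part of the commit:

    import torch
    import torch.nn as nn

    embedding_dim, context_size, U = 4, 2, 5
    conv = nn.Conv1d(
        in_channels=embedding_dim,
        out_channels=embedding_dim,
        kernel_size=context_size,
        padding=0,
        groups=embedding_dim,  # depthwise: one filter per channel
        bias=False,
    )
    x = torch.randn(1, embedding_dim, U)  # (N, C, U), as after permute(0, 2, 1)
    y = conv(x)
    # With no padding, each output frame sees exactly context_size inputs.
    assert y.shape == (1, embedding_dim, U - context_size + 1)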
@@ -62,4 +79,16 @@ class Decoder(nn.Module):
           Return a tensor of shape (N, U, embedding_dim).
         """
         embeding_out = self.embedding(y)
+        if self.context_size > 1:
+            embeding_out = embeding_out.permute(0, 2, 1)
+            if self.training is True:
+                embeding_out = F.pad(
+                    embeding_out, pad=(self.context_size - 1, 0)
+                )
+            else:
+                # During inference time, there is no need to do extra padding
+                # as we only need one output
+                assert embeding_out.size(-1) == self.context_size
+            embeding_out = self.conv(embeding_out)
+            embeding_out = embeding_out.permute(0, 2, 1)
         return embeding_out
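The padding arithmetic in the new forward() branch: in training mode the embedding is left-padded by context_size - 1 frames, so the convolution is causal and its output keeps length U; in eval mode the caller is expected to pass exactly context_size labels and gets a single output frame. A hedged sketch of just the shapes, not part of the commit:

    import torch
    import torch.nn.functional as F

    context_size, U = 2, 7
    x = torch.randn(1, 4, U)                      # (N, C, U)
    padded = F.pad(x, pad=(context_size - 1, 0))  # pad the left (past) side only
    assert padded.shape[-1] == U + context_size - 1
    # A kernel of size context_size then yields
    # (U + context_size - 1) - context_size + 1 == U output frames.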
egs/librispeech/ASR/transducer_stateless/test_decoder.py (new executable file, 61 lines)
@@ -0,0 +1,61 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./transducer_stateless/test_decoder.py
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from decoder import Decoder
+
+
+def test_decoder():
+    vocab_size = 3
+    blank_id = 0
+    embedding_dim = 128
+    context_size = 4
+
+    decoder = Decoder(
+        vocab_size=vocab_size,
+        embedding_dim=embedding_dim,
+        blank_id=blank_id,
+        context_size=context_size,
+    )
+    N = 100
+    U = 20
+    x = torch.randint(low=0, high=vocab_size, size=(N, U))
+    y = decoder(x)
+    assert y.shape == (N, U, embedding_dim)
+
+    # for inference
+    decoder.eval()
+    x = torch.randint(low=0, high=vocab_size, size=(N, context_size))
+    y = decoder(x)
+    assert y.shape == (N, 1, embedding_dim)
+
+
+def main():
+    test_decoder()
+
+
+if __name__ == "__main__":
+    main()
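Note how the two assertions in test_decoder() exercise both branches of the new forward(): a full length-U sequence while the module is in training mode, and, after decoder.eval(), a window of exactly context_size labels that yields a single output frame.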
@@ -202,6 +202,8 @@ def get_params() -> AttributeDict:
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             "use_feat_batchnorm": True,
+            # parameters for decoder
+            "context_size": 2,  # tri-gram
             # parameters for Noam
             "weight_decay": 1e-6,
             "warm_step": 80000,  # For the 100h subset, use 8k

@@ -233,6 +235,7 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
         blank_id=params.blank_id,
+        context_size=params.context_size,
     )
     return decoder