Rename Jointer to Joiner.

This commit is contained in:
Fangjun Kuang 2021-12-07 10:37:48 +08:00
parent 8038d13ec5
commit b86f45e217
8 changed files with 23 additions and 21 deletions

2
.gitignore vendored
View File

@ -6,3 +6,5 @@ exp
exp*/ exp*/
*.pt *.pt
download download
*.bak
*-bak

View File

@ -49,7 +49,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
# fmt: off # fmt: off
current_encoder_out = encoder_out[:, t:t+1, :] current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on # fmt: on
logits = model.jointer(current_encoder_out, decoder_out) logits = model.joiner(current_encoder_out, decoder_out)
log_prob = logits.log_softmax(dim=-1) log_prob = logits.log_softmax(dim=-1)
# log_prob is (N, 1, 1) # log_prob is (N, 1, 1)

View File

@ -26,7 +26,7 @@ from asr_datamodule import YesNoAsrDataModule
from transducer.beam_search import greedy_search from transducer.beam_search import greedy_search
from transducer.decoder import Decoder from transducer.decoder import Decoder
from transducer.encoder import Tdnn from transducer.encoder import Tdnn
from transducer.jointer import Jointer from transducer.joiner import Joiner
from transducer.model import Transducer from transducer.model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
@ -248,8 +248,8 @@ def get_transducer_model(params: AttributeDict):
embedding_dropout=0.4, embedding_dropout=0.4,
rnn_dropout=0.4, rnn_dropout=0.4,
) )
jointer = Jointer(input_dim=params.hidden_dim, output_dim=params.vocab_size) joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
transducer = Transducer(encoder=encoder, decoder=decoder, jointer=jointer) transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
return transducer return transducer

View File

@ -19,7 +19,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
class Jointer(nn.Module): class Joiner(nn.Module):
def __init__(self, input_dim: int, output_dim: int): def __init__(self, input_dim: int, output_dim: int):
super().__init__() super().__init__()

View File

@ -41,7 +41,7 @@ class Transducer(nn.Module):
self, self,
encoder: nn.Module, encoder: nn.Module,
decoder: nn.Module, decoder: nn.Module,
jointer: nn.Module, joiner: nn.Module,
): ):
""" """
Args: Args:
@ -54,7 +54,7 @@ class Transducer(nn.Module):
It is the prediction network in the paper. Its input shape It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain is (N, U) and its output shape is (N, U, C). It should contain
one attribute: `blank_id`. one attribute: `blank_id`.
jointer: joiner:
It has two inputs with shapes: (N, T, C) and (N, U, C). Its It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains output shape is (N, T, U, C). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax. unnormalized probs, i.e., not processed by log-softmax.
@ -62,7 +62,7 @@ class Transducer(nn.Module):
super().__init__() super().__init__()
self.encoder = encoder self.encoder = encoder
self.decoder = decoder self.decoder = decoder
self.jointer = jointer self.joiner = joiner
def forward( def forward(
self, self,
@ -103,7 +103,7 @@ class Transducer(nn.Module):
decoder_out, _ = self.decoder(sos_y_padded) decoder_out, _ = self.decoder(sos_y_padded)
logits = self.jointer(encoder_out, decoder_out) logits = self.joiner(encoder_out, decoder_out)
# rnnt_loss requires 0 padded targets # rnnt_loss requires 0 padded targets
y_padded = y.pad(mode="constant", padding_value=0) y_padded = y.pad(mode="constant", padding_value=0)

View File

@ -19,31 +19,31 @@
To run this file, do: To run this file, do:
cd icefall/egs/yesno/ASR cd icefall/egs/yesno/ASR
python ./transducer/test_jointer.py python ./transducer/test_joiner.py
""" """
import torch import torch
from transducer.jointer import Jointer from transducer.joiner import Joiner
def test_jointer(): def test_joiner():
N = 2 N = 2
T = 3 T = 3
C = 4 C = 4
U = 5 U = 5
jointer = Jointer(C, 10) joiner = Joiner(C, 10)
encoder_out = torch.rand(N, T, C) encoder_out = torch.rand(N, T, C)
decoder_out = torch.rand(N, U, C) decoder_out = torch.rand(N, U, C)
joint = jointer(encoder_out, decoder_out) joint = joiner(encoder_out, decoder_out)
assert joint.shape == (N, T, U, 10) assert joint.shape == (N, T, U, 10)
def main(): def main():
test_jointer() test_joiner()
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -27,7 +27,7 @@ import k2
import torch import torch
from transducer.decoder import Decoder from transducer.decoder import Decoder
from transducer.encoder import Tdnn from transducer.encoder import Tdnn
from transducer.jointer import Jointer from transducer.joiner import Joiner
from transducer.model import Transducer from transducer.model import Transducer
@ -54,8 +54,8 @@ def test_transducer():
rnn_dropout=0.0, rnn_dropout=0.0,
) )
jointer = Jointer(output_dim, vocab_size) joiner = Joiner(output_dim, vocab_size)
transducer = Transducer(encoder=encoder, decoder=decoder, jointer=jointer) transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
y = k2.RaggedTensor([[1, 2, 1], [1, 1, 1, 2, 1]]) y = k2.RaggedTensor([[1, 2, 1], [1, 1, 1, 2, 1]])
N = y.dim0 N = y.dim0

View File

@ -34,7 +34,7 @@ from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from transducer.decoder import Decoder from transducer.decoder import Decoder
from transducer.encoder import Tdnn from transducer.encoder import Tdnn
from transducer.jointer import Jointer from transducer.joiner import Joiner
from transducer.model import Transducer from transducer.model import Transducer
from icefall.checkpoint import load_checkpoint from icefall.checkpoint import load_checkpoint
@ -465,8 +465,8 @@ def get_transducer_model(params: AttributeDict):
embedding_dropout=0.4, embedding_dropout=0.4,
rnn_dropout=0.4, rnn_dropout=0.4,
) )
jointer = Jointer(input_dim=params.hidden_dim, output_dim=params.vocab_size) joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
transducer = Transducer(encoder=encoder, decoder=decoder, jointer=jointer) transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
return transducer return transducer