Rename Jointer to Joiner.

This commit is contained in:
Fangjun Kuang 2021-12-07 10:37:48 +08:00
parent 8038d13ec5
commit b86f45e217
8 changed files with 23 additions and 21 deletions

2
.gitignore vendored
View File

@ -6,3 +6,5 @@ exp
exp*/
*.pt
download
*.bak
*-bak

View File

@ -49,7 +49,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on
logits = model.jointer(current_encoder_out, decoder_out)
logits = model.joiner(current_encoder_out, decoder_out)
log_prob = logits.log_softmax(dim=-1)
# log_prob is (N, 1, 1)

View File

@ -26,7 +26,7 @@ from asr_datamodule import YesNoAsrDataModule
from transducer.beam_search import greedy_search
from transducer.decoder import Decoder
from transducer.encoder import Tdnn
from transducer.jointer import Jointer
from transducer.joiner import Joiner
from transducer.model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
@ -248,8 +248,8 @@ def get_transducer_model(params: AttributeDict):
embedding_dropout=0.4,
rnn_dropout=0.4,
)
jointer = Jointer(input_dim=params.hidden_dim, output_dim=params.vocab_size)
transducer = Transducer(encoder=encoder, decoder=decoder, jointer=jointer)
joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
return transducer

View File

@ -19,7 +19,7 @@ import torch.nn as nn
import torch.nn.functional as F
class Jointer(nn.Module):
class Joiner(nn.Module):
def __init__(self, input_dim: int, output_dim: int):
super().__init__()

View File

@ -41,7 +41,7 @@ class Transducer(nn.Module):
self,
encoder: nn.Module,
decoder: nn.Module,
jointer: nn.Module,
joiner: nn.Module,
):
"""
Args:
@ -54,7 +54,7 @@ class Transducer(nn.Module):
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain
one attribute: `blank_id`.
jointer:
joiner:
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
@ -62,7 +62,7 @@ class Transducer(nn.Module):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.jointer = jointer
self.joiner = joiner
def forward(
self,
@ -103,7 +103,7 @@ class Transducer(nn.Module):
decoder_out, _ = self.decoder(sos_y_padded)
logits = self.jointer(encoder_out, decoder_out)
logits = self.joiner(encoder_out, decoder_out)
# rnnt_loss requires 0 padded targets
y_padded = y.pad(mode="constant", padding_value=0)

View File

@ -19,31 +19,31 @@
To run this file, do:
cd icefall/egs/yesno/ASR
python ./transducer/test_jointer.py
python ./transducer/test_joiner.py
"""
import torch
from transducer.jointer import Jointer
from transducer.joiner import Joiner
def test_jointer():
def test_joiner():
N = 2
T = 3
C = 4
U = 5
jointer = Jointer(C, 10)
joiner = Joiner(C, 10)
encoder_out = torch.rand(N, T, C)
decoder_out = torch.rand(N, U, C)
joint = jointer(encoder_out, decoder_out)
joint = joiner(encoder_out, decoder_out)
assert joint.shape == (N, T, U, 10)
def main():
test_jointer()
test_joiner()
if __name__ == "__main__":

View File

@ -27,7 +27,7 @@ import k2
import torch
from transducer.decoder import Decoder
from transducer.encoder import Tdnn
from transducer.jointer import Jointer
from transducer.joiner import Joiner
from transducer.model import Transducer
@ -54,8 +54,8 @@ def test_transducer():
rnn_dropout=0.0,
)
jointer = Jointer(output_dim, vocab_size)
transducer = Transducer(encoder=encoder, decoder=decoder, jointer=jointer)
joiner = Joiner(output_dim, vocab_size)
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
y = k2.RaggedTensor([[1, 2, 1], [1, 1, 1, 2, 1]])
N = y.dim0

View File

@ -34,7 +34,7 @@ from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transducer.decoder import Decoder
from transducer.encoder import Tdnn
from transducer.jointer import Jointer
from transducer.joiner import Joiner
from transducer.model import Transducer
from icefall.checkpoint import load_checkpoint
@ -465,8 +465,8 @@ def get_transducer_model(params: AttributeDict):
embedding_dropout=0.4,
rnn_dropout=0.4,
)
jointer = Jointer(input_dim=params.hidden_dim, output_dim=params.vocab_size)
transducer = Transducer(encoder=encoder, decoder=decoder, jointer=jointer)
joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
return transducer