From b86f45e217270781f8ce72b79b459897fa055844 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 7 Dec 2021 10:37:48 +0800 Subject: [PATCH] Rename Jointer to Joiner. --- .gitignore | 2 ++ egs/yesno/ASR/transducer/beam_search.py | 2 +- egs/yesno/ASR/transducer/decode.py | 6 +++--- egs/yesno/ASR/transducer/{jointer.py => joiner.py} | 2 +- egs/yesno/ASR/transducer/model.py | 8 ++++---- .../transducer/{test_jointer.py => test_joiner.py} | 12 ++++++------ egs/yesno/ASR/transducer/test_transducer.py | 6 +++--- egs/yesno/ASR/transducer/train.py | 6 +++--- 8 files changed, 23 insertions(+), 21 deletions(-) rename egs/yesno/ASR/transducer/{jointer.py => joiner.py} (98%) rename egs/yesno/ASR/transducer/{test_jointer.py => test_joiner.py} (83%) diff --git a/.gitignore b/.gitignore index f4f703243..31da5ed3e 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,5 @@ exp exp*/ *.pt download +*.bak +*-bak diff --git a/egs/yesno/ASR/transducer/beam_search.py b/egs/yesno/ASR/transducer/beam_search.py index 1f743823b..ae0f39478 100644 --- a/egs/yesno/ASR/transducer/beam_search.py +++ b/egs/yesno/ASR/transducer/beam_search.py @@ -49,7 +49,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: # fmt: off current_encoder_out = encoder_out[:, t:t+1, :] # fmt: on - logits = model.jointer(current_encoder_out, decoder_out) + logits = model.joiner(current_encoder_out, decoder_out) log_prob = logits.log_softmax(dim=-1) # log_prob is (N, 1, 1) diff --git a/egs/yesno/ASR/transducer/decode.py b/egs/yesno/ASR/transducer/decode.py index aa1b214da..abb34da4c 100755 --- a/egs/yesno/ASR/transducer/decode.py +++ b/egs/yesno/ASR/transducer/decode.py @@ -26,7 +26,7 @@ from asr_datamodule import YesNoAsrDataModule from transducer.beam_search import greedy_search from transducer.decoder import Decoder from transducer.encoder import Tdnn -from transducer.jointer import Jointer +from transducer.joiner import Joiner from transducer.model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint @@ -248,8 +248,8 @@ def get_transducer_model(params: AttributeDict): embedding_dropout=0.4, rnn_dropout=0.4, ) - jointer = Jointer(input_dim=params.hidden_dim, output_dim=params.vocab_size) - transducer = Transducer(encoder=encoder, decoder=decoder, jointer=jointer) + joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size) + transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner) return transducer diff --git a/egs/yesno/ASR/transducer/jointer.py b/egs/yesno/ASR/transducer/joiner.py similarity index 98% rename from egs/yesno/ASR/transducer/jointer.py rename to egs/yesno/ASR/transducer/joiner.py index 509554841..0422f8a6f 100644 --- a/egs/yesno/ASR/transducer/jointer.py +++ b/egs/yesno/ASR/transducer/joiner.py @@ -19,7 +19,7 @@ import torch.nn as nn import torch.nn.functional as F -class Jointer(nn.Module): +class Joiner(nn.Module): def __init__(self, input_dim: int, output_dim: int): super().__init__() diff --git a/egs/yesno/ASR/transducer/model.py b/egs/yesno/ASR/transducer/model.py index bbb6ba7e8..caf9bed37 100644 --- a/egs/yesno/ASR/transducer/model.py +++ b/egs/yesno/ASR/transducer/model.py @@ -41,7 +41,7 @@ class Transducer(nn.Module): self, encoder: nn.Module, decoder: nn.Module, - jointer: nn.Module, + joiner: nn.Module, ): """ Args: @@ -54,7 +54,7 @@ class Transducer(nn.Module): It is the prediction network in the paper. Its input shape is (N, U) and its output shape is (N, U, C). It should contain one attribute: `blank_id`. - jointer: + joiner: It has two inputs with shapes: (N, T, C) and (N, U, C). Its output shape is (N, T, U, C). Note that its output contains unnormalized probs, i.e., not processed by log-softmax. @@ -62,7 +62,7 @@ class Transducer(nn.Module): super().__init__() self.encoder = encoder self.decoder = decoder - self.jointer = jointer + self.joiner = joiner def forward( self, @@ -103,7 +103,7 @@ class Transducer(nn.Module): decoder_out, _ = self.decoder(sos_y_padded) - logits = self.jointer(encoder_out, decoder_out) + logits = self.joiner(encoder_out, decoder_out) # rnnt_loss requires 0 padded targets y_padded = y.pad(mode="constant", padding_value=0) diff --git a/egs/yesno/ASR/transducer/test_jointer.py b/egs/yesno/ASR/transducer/test_joiner.py similarity index 83% rename from egs/yesno/ASR/transducer/test_jointer.py rename to egs/yesno/ASR/transducer/test_joiner.py index 6db0270e1..2773ca319 100755 --- a/egs/yesno/ASR/transducer/test_jointer.py +++ b/egs/yesno/ASR/transducer/test_joiner.py @@ -19,31 +19,31 @@ To run this file, do: cd icefall/egs/yesno/ASR - python ./transducer/test_jointer.py + python ./transducer/test_joiner.py """ import torch -from transducer.jointer import Jointer +from transducer.joiner import Joiner -def test_jointer(): +def test_joiner(): N = 2 T = 3 C = 4 U = 5 - jointer = Jointer(C, 10) + joiner = Joiner(C, 10) encoder_out = torch.rand(N, T, C) decoder_out = torch.rand(N, U, C) - joint = jointer(encoder_out, decoder_out) + joint = joiner(encoder_out, decoder_out) assert joint.shape == (N, T, U, 10) def main(): - test_jointer() + test_joiner() if __name__ == "__main__": diff --git a/egs/yesno/ASR/transducer/test_transducer.py b/egs/yesno/ASR/transducer/test_transducer.py index 6aa7e4a53..db7bf9c68 100755 --- a/egs/yesno/ASR/transducer/test_transducer.py +++ b/egs/yesno/ASR/transducer/test_transducer.py @@ -27,7 +27,7 @@ import k2 import torch from transducer.decoder import Decoder from transducer.encoder import Tdnn -from transducer.jointer import Jointer +from transducer.joiner import Joiner from transducer.model import Transducer @@ -54,8 +54,8 @@ def test_transducer(): rnn_dropout=0.0, ) - jointer = Jointer(output_dim, vocab_size) - transducer = Transducer(encoder=encoder, decoder=decoder, jointer=jointer) + joiner = Joiner(output_dim, vocab_size) + transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner) y = k2.RaggedTensor([[1, 2, 1], [1, 1, 1, 2, 1]]) N = y.dim0 diff --git a/egs/yesno/ASR/transducer/train.py b/egs/yesno/ASR/transducer/train.py index 62172b14e..7d2d1edeb 100755 --- a/egs/yesno/ASR/transducer/train.py +++ b/egs/yesno/ASR/transducer/train.py @@ -34,7 +34,7 @@ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transducer.decoder import Decoder from transducer.encoder import Tdnn -from transducer.jointer import Jointer +from transducer.joiner import Joiner from transducer.model import Transducer from icefall.checkpoint import load_checkpoint @@ -465,8 +465,8 @@ def get_transducer_model(params: AttributeDict): embedding_dropout=0.4, rnn_dropout=0.4, ) - jointer = Jointer(input_dim=params.hidden_dim, output_dim=params.vocab_size) - transducer = Transducer(encoder=encoder, decoder=decoder, jointer=jointer) + joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size) + transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner) return transducer