mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
Rename Jointer to Joiner.
This commit is contained in:
parent
8038d13ec5
commit
b86f45e217
2
.gitignore
vendored
2
.gitignore
vendored
@ -6,3 +6,5 @@ exp
|
||||
exp*/
|
||||
*.pt
|
||||
download
|
||||
*.bak
|
||||
*-bak
|
||||
|
@ -49,7 +49,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
||||
# fmt: on
|
||||
logits = model.jointer(current_encoder_out, decoder_out)
|
||||
logits = model.joiner(current_encoder_out, decoder_out)
|
||||
|
||||
log_prob = logits.log_softmax(dim=-1)
|
||||
# log_prob is (N, 1, 1)
|
||||
|
@ -26,7 +26,7 @@ from asr_datamodule import YesNoAsrDataModule
|
||||
from transducer.beam_search import greedy_search
|
||||
from transducer.decoder import Decoder
|
||||
from transducer.encoder import Tdnn
|
||||
from transducer.jointer import Jointer
|
||||
from transducer.joiner import Joiner
|
||||
from transducer.model import Transducer
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
@ -248,8 +248,8 @@ def get_transducer_model(params: AttributeDict):
|
||||
embedding_dropout=0.4,
|
||||
rnn_dropout=0.4,
|
||||
)
|
||||
jointer = Jointer(input_dim=params.hidden_dim, output_dim=params.vocab_size)
|
||||
transducer = Transducer(encoder=encoder, decoder=decoder, jointer=jointer)
|
||||
joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
|
||||
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
|
||||
return transducer
|
||||
|
||||
|
||||
|
@ -19,7 +19,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Jointer(nn.Module):
|
||||
class Joiner(nn.Module):
|
||||
def __init__(self, input_dim: int, output_dim: int):
|
||||
super().__init__()
|
||||
|
@ -41,7 +41,7 @@ class Transducer(nn.Module):
|
||||
self,
|
||||
encoder: nn.Module,
|
||||
decoder: nn.Module,
|
||||
jointer: nn.Module,
|
||||
joiner: nn.Module,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -54,7 +54,7 @@ class Transducer(nn.Module):
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, C). It should contain
|
||||
one attribute: `blank_id`.
|
||||
jointer:
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
|
||||
output shape is (N, T, U, C). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
@ -62,7 +62,7 @@ class Transducer(nn.Module):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.jointer = jointer
|
||||
self.joiner = joiner
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -103,7 +103,7 @@ class Transducer(nn.Module):
|
||||
|
||||
decoder_out, _ = self.decoder(sos_y_padded)
|
||||
|
||||
logits = self.jointer(encoder_out, decoder_out)
|
||||
logits = self.joiner(encoder_out, decoder_out)
|
||||
|
||||
# rnnt_loss requires 0 padded targets
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
@ -19,31 +19,31 @@
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/yesno/ASR
|
||||
python ./transducer/test_jointer.py
|
||||
python ./transducer/test_joiner.py
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
from transducer.jointer import Jointer
|
||||
from transducer.joiner import Joiner
|
||||
|
||||
|
||||
def test_jointer():
|
||||
def test_joiner():
|
||||
N = 2
|
||||
T = 3
|
||||
C = 4
|
||||
U = 5
|
||||
|
||||
jointer = Jointer(C, 10)
|
||||
joiner = Joiner(C, 10)
|
||||
|
||||
encoder_out = torch.rand(N, T, C)
|
||||
decoder_out = torch.rand(N, U, C)
|
||||
|
||||
joint = jointer(encoder_out, decoder_out)
|
||||
joint = joiner(encoder_out, decoder_out)
|
||||
assert joint.shape == (N, T, U, 10)
|
||||
|
||||
|
||||
def main():
|
||||
test_jointer()
|
||||
test_joiner()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
@ -27,7 +27,7 @@ import k2
|
||||
import torch
|
||||
from transducer.decoder import Decoder
|
||||
from transducer.encoder import Tdnn
|
||||
from transducer.jointer import Jointer
|
||||
from transducer.joiner import Joiner
|
||||
from transducer.model import Transducer
|
||||
|
||||
|
||||
@ -54,8 +54,8 @@ def test_transducer():
|
||||
rnn_dropout=0.0,
|
||||
)
|
||||
|
||||
jointer = Jointer(output_dim, vocab_size)
|
||||
transducer = Transducer(encoder=encoder, decoder=decoder, jointer=jointer)
|
||||
joiner = Joiner(output_dim, vocab_size)
|
||||
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
|
||||
|
||||
y = k2.RaggedTensor([[1, 2, 1], [1, 1, 1, 2, 1]])
|
||||
N = y.dim0
|
||||
|
@ -34,7 +34,7 @@ from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from transducer.decoder import Decoder
|
||||
from transducer.encoder import Tdnn
|
||||
from transducer.jointer import Jointer
|
||||
from transducer.joiner import Joiner
|
||||
from transducer.model import Transducer
|
||||
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
@ -465,8 +465,8 @@ def get_transducer_model(params: AttributeDict):
|
||||
embedding_dropout=0.4,
|
||||
rnn_dropout=0.4,
|
||||
)
|
||||
jointer = Jointer(input_dim=params.hidden_dim, output_dim=params.vocab_size)
|
||||
transducer = Transducer(encoder=encoder, decoder=decoder, jointer=jointer)
|
||||
joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
|
||||
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
|
||||
|
||||
return transducer
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user