mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Rename Jointer to Joiner.
This commit is contained in:
parent
8038d13ec5
commit
b86f45e217
2
.gitignore
vendored
2
.gitignore
vendored
@ -6,3 +6,5 @@ exp
|
|||||||
exp*/
|
exp*/
|
||||||
*.pt
|
*.pt
|
||||||
download
|
download
|
||||||
|
*.bak
|
||||||
|
*-bak
|
||||||
|
@ -49,7 +49,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
|||||||
# fmt: off
|
# fmt: off
|
||||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
current_encoder_out = encoder_out[:, t:t+1, :]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
logits = model.jointer(current_encoder_out, decoder_out)
|
logits = model.joiner(current_encoder_out, decoder_out)
|
||||||
|
|
||||||
log_prob = logits.log_softmax(dim=-1)
|
log_prob = logits.log_softmax(dim=-1)
|
||||||
# log_prob is (N, 1, 1)
|
# log_prob is (N, 1, 1)
|
||||||
|
@ -26,7 +26,7 @@ from asr_datamodule import YesNoAsrDataModule
|
|||||||
from transducer.beam_search import greedy_search
|
from transducer.beam_search import greedy_search
|
||||||
from transducer.decoder import Decoder
|
from transducer.decoder import Decoder
|
||||||
from transducer.encoder import Tdnn
|
from transducer.encoder import Tdnn
|
||||||
from transducer.jointer import Jointer
|
from transducer.joiner import Joiner
|
||||||
from transducer.model import Transducer
|
from transducer.model import Transducer
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
@ -248,8 +248,8 @@ def get_transducer_model(params: AttributeDict):
|
|||||||
embedding_dropout=0.4,
|
embedding_dropout=0.4,
|
||||||
rnn_dropout=0.4,
|
rnn_dropout=0.4,
|
||||||
)
|
)
|
||||||
jointer = Jointer(input_dim=params.hidden_dim, output_dim=params.vocab_size)
|
joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
|
||||||
transducer = Transducer(encoder=encoder, decoder=decoder, jointer=jointer)
|
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
|
||||||
return transducer
|
return transducer
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
class Jointer(nn.Module):
|
class Joiner(nn.Module):
|
||||||
def __init__(self, input_dim: int, output_dim: int):
|
def __init__(self, input_dim: int, output_dim: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@ -41,7 +41,7 @@ class Transducer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
encoder: nn.Module,
|
encoder: nn.Module,
|
||||||
decoder: nn.Module,
|
decoder: nn.Module,
|
||||||
jointer: nn.Module,
|
joiner: nn.Module,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -54,7 +54,7 @@ class Transducer(nn.Module):
|
|||||||
It is the prediction network in the paper. Its input shape
|
It is the prediction network in the paper. Its input shape
|
||||||
is (N, U) and its output shape is (N, U, C). It should contain
|
is (N, U) and its output shape is (N, U, C). It should contain
|
||||||
one attribute: `blank_id`.
|
one attribute: `blank_id`.
|
||||||
jointer:
|
joiner:
|
||||||
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
|
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
|
||||||
output shape is (N, T, U, C). Note that its output contains
|
output shape is (N, T, U, C). Note that its output contains
|
||||||
unnormalized probs, i.e., not processed by log-softmax.
|
unnormalized probs, i.e., not processed by log-softmax.
|
||||||
@ -62,7 +62,7 @@ class Transducer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
self.decoder = decoder
|
self.decoder = decoder
|
||||||
self.jointer = jointer
|
self.joiner = joiner
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -103,7 +103,7 @@ class Transducer(nn.Module):
|
|||||||
|
|
||||||
decoder_out, _ = self.decoder(sos_y_padded)
|
decoder_out, _ = self.decoder(sos_y_padded)
|
||||||
|
|
||||||
logits = self.jointer(encoder_out, decoder_out)
|
logits = self.joiner(encoder_out, decoder_out)
|
||||||
|
|
||||||
# rnnt_loss requires 0 padded targets
|
# rnnt_loss requires 0 padded targets
|
||||||
y_padded = y.pad(mode="constant", padding_value=0)
|
y_padded = y.pad(mode="constant", padding_value=0)
|
||||||
|
@ -19,31 +19,31 @@
|
|||||||
To run this file, do:
|
To run this file, do:
|
||||||
|
|
||||||
cd icefall/egs/yesno/ASR
|
cd icefall/egs/yesno/ASR
|
||||||
python ./transducer/test_jointer.py
|
python ./transducer/test_joiner.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transducer.jointer import Jointer
|
from transducer.joiner import Joiner
|
||||||
|
|
||||||
|
|
||||||
def test_jointer():
|
def test_joiner():
|
||||||
N = 2
|
N = 2
|
||||||
T = 3
|
T = 3
|
||||||
C = 4
|
C = 4
|
||||||
U = 5
|
U = 5
|
||||||
|
|
||||||
jointer = Jointer(C, 10)
|
joiner = Joiner(C, 10)
|
||||||
|
|
||||||
encoder_out = torch.rand(N, T, C)
|
encoder_out = torch.rand(N, T, C)
|
||||||
decoder_out = torch.rand(N, U, C)
|
decoder_out = torch.rand(N, U, C)
|
||||||
|
|
||||||
joint = jointer(encoder_out, decoder_out)
|
joint = joiner(encoder_out, decoder_out)
|
||||||
assert joint.shape == (N, T, U, 10)
|
assert joint.shape == (N, T, U, 10)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
test_jointer()
|
test_joiner()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
@ -27,7 +27,7 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
from transducer.decoder import Decoder
|
from transducer.decoder import Decoder
|
||||||
from transducer.encoder import Tdnn
|
from transducer.encoder import Tdnn
|
||||||
from transducer.jointer import Jointer
|
from transducer.joiner import Joiner
|
||||||
from transducer.model import Transducer
|
from transducer.model import Transducer
|
||||||
|
|
||||||
|
|
||||||
@ -54,8 +54,8 @@ def test_transducer():
|
|||||||
rnn_dropout=0.0,
|
rnn_dropout=0.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
jointer = Jointer(output_dim, vocab_size)
|
joiner = Joiner(output_dim, vocab_size)
|
||||||
transducer = Transducer(encoder=encoder, decoder=decoder, jointer=jointer)
|
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
|
||||||
|
|
||||||
y = k2.RaggedTensor([[1, 2, 1], [1, 1, 1, 2, 1]])
|
y = k2.RaggedTensor([[1, 2, 1], [1, 1, 1, 2, 1]])
|
||||||
N = y.dim0
|
N = y.dim0
|
||||||
|
@ -34,7 +34,7 @@ from torch.nn.utils import clip_grad_norm_
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from transducer.decoder import Decoder
|
from transducer.decoder import Decoder
|
||||||
from transducer.encoder import Tdnn
|
from transducer.encoder import Tdnn
|
||||||
from transducer.jointer import Jointer
|
from transducer.joiner import Joiner
|
||||||
from transducer.model import Transducer
|
from transducer.model import Transducer
|
||||||
|
|
||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
@ -465,8 +465,8 @@ def get_transducer_model(params: AttributeDict):
|
|||||||
embedding_dropout=0.4,
|
embedding_dropout=0.4,
|
||||||
rnn_dropout=0.4,
|
rnn_dropout=0.4,
|
||||||
)
|
)
|
||||||
jointer = Jointer(input_dim=params.hidden_dim, output_dim=params.vocab_size)
|
joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
|
||||||
transducer = Transducer(encoder=encoder, decoder=decoder, jointer=jointer)
|
transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
|
||||||
|
|
||||||
return transducer
|
return transducer
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user