Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)
Refactor transformer.py
Commit f6091b10c0 (parent 1fa30998da)
@@ -89,15 +89,21 @@ class Conformer(Transformer):
     ) -> Tuple[Tensor, Optional[Tensor]]:
         """
         Args:
-          x: Tensor of dimension (batch_size, num_features, input_length).
-          supervisions : Supervison in lhotse format, i.e., batch['supervisions']
+          x:
+            The model input. Its shape is [N, T, C].
+          supervisions:
+            Supervision in lhotse format.
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+            CAUTION: It contains length information, i.e., start and number of
+            frames, before subsampling.
+            It is read directly from the batch, without any sorting. It is used
+            to compute the encoder padding mask, which is used as the memory
+            key padding mask for the decoder.

         Returns:
           Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
           Tensor: Mask tensor of dimension (batch_size, input_length)
         """
-        x = x.permute(0, 2, 1)  # (B, F, T) -> (B, T, F)
-
         x = self.encoder_embed(x)
         x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
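A minimal sketch of the padding-mask arithmetic the new docstring describes (plain PyTorch, not part of the commit; the actual helper in icefall may differ):

    import torch

    supervisions = {
        "sequence_idx": torch.tensor([0, 1]),
        "start_frame": torch.tensor([0, 0]),
        "num_frames": torch.tensor([16, 8]),  # lengths BEFORE subsampling
    }
    # After 1/4 subsampling, each length becomes ((T - 1) // 2 - 1) // 2.
    lens = ((supervisions["num_frames"] - 1) // 2 - 1) // 2  # tensor([3, 1])
    max_len = int(lens.max())
    # True marks padded positions; the mask has shape [N, max_len] and is
    # later used as the memory key padding mask for the decoder.
    mask = torch.arange(max_len).unsqueeze(0) >= lens.unsqueeze(1)
    print(mask)  # tensor([[False, False, False], [False,  True,  True]])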
@@ -796,8 +802,7 @@ class RelPositionMultiheadAttention(nn.Module):
                 bsz, num_heads, tgt_len, src_len
             )
             attn_output_weights = attn_output_weights.masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2),
-                float("-inf"),
+                key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"),
             )
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
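The masked_fill call above is the standard way to exclude padded source positions from attention; a self-contained sketch (not from the commit):

    import torch

    bsz, num_heads, tgt_len, src_len = 1, 2, 3, 4
    attn_output_weights = torch.zeros(bsz, num_heads, tgt_len, src_len)
    # key_padding_mask has shape [bsz, src_len]; True marks padding.
    key_padding_mask = torch.tensor([[False, False, True, True]])
    # unsqueeze(1).unsqueeze(2) -> [bsz, 1, 1, src_len], broadcasting over
    # heads and target positions; -inf scores get zero softmax probability.
    attn_output_weights = attn_output_weights.masked_fill(
        key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
    )
    print(attn_output_weights.softmax(dim=-1)[0, 0, 0])
    # tensor([0.5000, 0.5000, 0.0000, 0.0000])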
@@ -867,12 +872,7 @@ class ConvolutionModule(nn.Module):
         )
         self.norm = nn.BatchNorm1d(channels)
         self.pointwise_conv2 = nn.Conv1d(
-            channels,
-            channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-            bias=bias,
+            channels, channels, kernel_size=1, stride=1, padding=0, bias=bias,
         )
         self.activation = Swish()
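The arguments are only collapsed onto one line; behavior is unchanged. A kernel_size=1 Conv1d like pointwise_conv2 acts as a per-frame linear projection over channels, e.g. (illustration only):

    import torch
    import torch.nn as nn

    pw = nn.Conv1d(4, 4, kernel_size=1, stride=1, padding=0, bias=True)
    x = torch.rand(2, 4, 10)       # [N, C, T]
    y = pw(x)
    assert y.shape == (2, 4, 10)   # time length is unchanged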
@@ -147,15 +147,10 @@ def decode_one_batch(
     feature = feature.to(device)
     # at entry, feature is [N, T, C]
-
-    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]

     supervisions = batch["supervisions"]

     nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
-    # nnet_output is [N, C, T]
+    # nnet_output is [N, T, C]

-    nnet_output = nnet_output.permute(0, 2, 1)
-    # now nnet_output is [N, T, C]
-
     supervision_segments = torch.stack(
         (
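The deleted lines performed the layout round trip sketched below; after this refactor the model consumes and returns [N, T, C] directly, so callers no longer permute (sketch, not from the commit):

    import torch

    N, T, C = 2, 5, 3
    feature = torch.rand(N, T, C)           # [N, T, C], as read from the batch
    old_style = feature.permute(0, 2, 1)    # [N, C, T], the pre-refactor layout
    assert old_style.permute(0, 2, 1).shape == (N, T, C)  # inverse round trip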
@@ -227,6 +222,8 @@ def decode_one_batch(
             model=model,
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=lexicon.sos_id,
+            eos_id=lexicon.eos_id,
         )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
@@ -468,5 +465,8 @@ def main():
     logging.info("Done!")


+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
 if __name__ == "__main__":
     main()

egs/librispeech/ASR/conformer_ctc/subsampling.py (new file, 144 lines)
@@ -0,0 +1,144 @@
+import torch
+import torch.nn as nn
+
+
+class Conv2dSubsampling(nn.Module):
+    """Convolutional 2D subsampling (to 1/4 length).
+
+    Convert an input of shape [N, T, idim] to an output
+    with shape [N, T', odim], where
+    T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
+
+    It is based on
+    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
+    """
+
+    def __init__(self, idim: int, odim: int) -> None:
+        """
+        Args:
+          idim:
+            Input dim. The input shape is [N, T, idim].
+            Caution: It requires: T >= 7, idim >= 7
+          odim:
+            Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+        """
+        assert idim >= 7
+        super().__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(
+                in_channels=1, out_channels=odim, kernel_size=3, stride=2
+            ),
+            nn.ReLU(),
+            nn.Conv2d(
+                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
+            ),
+            nn.ReLU(),
+        )
+        self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Subsample x.
+
+        Args:
+          x:
+            Its shape is [N, T, idim].
+
+        Returns:
+          Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+        """
+        # On entry, x is [N, T, idim]
+        x = x.unsqueeze(1)  # [N, T, idim] -> [N, 1, T, idim], i.e., [N, C, H, W]
+        x = self.conv(x)
+        # Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        # Now x is of shape [N, ((T-1)//2 - 1)//2, odim]
+        return x
+
+
+class VggSubsampling(nn.Module):
+    """Trying to follow the setup described in the following paper:
+    https://arxiv.org/pdf/1910.09799.pdf
+
+    This paper is not 100% explicit, so I am guessing to some extent
+    and trying to compare with other VGG implementations.
+
+    Convert an input of shape [N, T, idim] to an output
+    with shape [N, T', odim], where
+    T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
+    """
+
+    def __init__(self, idim: int, odim: int) -> None:
+        """Construct a VggSubsampling object.
+
+        This uses 2 VGG blocks with 2 Conv2d layers each,
+        subsampling its input by a factor of 4 in the time dimension.
+
+        Args:
+          idim:
+            Input dim. The input shape is [N, T, idim].
+            Caution: It requires: T >= 7, idim >= 7
+          odim:
+            Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+        """
+        super().__init__()
+
+        cur_channels = 1
+        layers = []
+        block_dims = [32, 64]
+
+        # The decision to use padding=1 for the 1st convolution, then padding=0
+        # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
+        # a back-compatibility concern so that the number of frames at the
+        # output would be equal to:
+        #   (((T-1)//2)-1)//2.
+        # We can consider changing this by using padding=1 on the
+        # 2nd convolution, so the num-frames at the output would be T//4.
+        for block_dim in block_dims:
+            layers.append(
+                torch.nn.Conv2d(
+                    in_channels=cur_channels,
+                    out_channels=block_dim,
+                    kernel_size=3,
+                    padding=1,
+                    stride=1,
+                )
+            )
+            layers.append(torch.nn.ReLU())
+            layers.append(
+                torch.nn.Conv2d(
+                    in_channels=block_dim,
+                    out_channels=block_dim,
+                    kernel_size=3,
+                    padding=0,
+                    stride=1,
+                )
+            )
+            layers.append(
+                torch.nn.MaxPool2d(
+                    kernel_size=2, stride=2, padding=0, ceil_mode=True
+                )
+            )
+            cur_channels = block_dim
+
+        self.layers = nn.Sequential(*layers)
+
+        self.out = nn.Linear(
+            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Subsample x.
+
+        Args:
+          x:
+            Its shape is [N, T, idim].
+
+        Returns:
+          Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+        """
+        x = x.unsqueeze(1)
+        x = self.layers(x)
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        return x
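A quick shape check of the new module (my own example; the tests below exercise the same formula more thoroughly). For T = 100: ((100 - 1) // 2 - 1) // 2 = 24, close to 100 // 4 = 25.

    import torch
    from subsampling import Conv2dSubsampling  # assumes the new file is importable

    model = Conv2dSubsampling(idim=80, odim=256)
    x = torch.rand(4, 100, 80)      # [N, T, idim]
    y = model(x)
    print(y.shape)                  # torch.Size([4, 24, 256])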

egs/librispeech/ASR/conformer_ctc/test_subsampling.py (new executable file, 33 lines)
@@ -0,0 +1,33 @@
+#!/usr/bin/env python3
+
+from subsampling import Conv2dSubsampling
+from subsampling import VggSubsampling
+import torch
+
+
+def test_conv2d_subsampling():
+    N = 3
+    odim = 2
+
+    for T in range(7, 19):
+        for idim in range(7, 20):
+            model = Conv2dSubsampling(idim=idim, odim=odim)
+            x = torch.empty(N, T, idim)
+            y = model(x)
+            assert y.shape[0] == N
+            assert y.shape[1] == ((T - 1) // 2 - 1) // 2
+            assert y.shape[2] == odim
+
+
+def test_vgg_subsampling():
+    N = 3
+    odim = 2
+
+    for T in range(7, 19):
+        for idim in range(7, 20):
+            model = VggSubsampling(idim=idim, odim=odim)
+            x = torch.empty(N, T, idim)
+            y = model(x)
+            assert y.shape[0] == N
+            assert y.shape[1] == ((T - 1) // 2 - 1) // 2
+            assert y.shape[2] == odim

egs/librispeech/ASR/conformer_ctc/test_transformer.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+#!/usr/bin/env python3
+
+import torch
+from transformer import Transformer, encoder_padding_mask
+
+
+def test_encoder_padding_mask():
+    supervisions = {
+        "sequence_idx": torch.tensor([0, 1, 2]),
+        "start_frame": torch.tensor([0, 0, 0]),
+        "num_frames": torch.tensor([18, 7, 13]),
+    }
+
+    max_len = ((18 - 1) // 2 - 1) // 2
+    mask = encoder_padding_mask(max_len, supervisions)
+    expected_mask = torch.tensor(
+        [
+            [False, False, False],  # ((18 - 1)//2 - 1)//2 == 3
+            [False, True, True],  # ((7 - 1)//2 - 1)//2 == 1
+            [False, False, True],  # ((13 - 1)//2 - 1)//2 == 2
+        ]
+    )
+    assert torch.all(torch.eq(mask, expected_mask))
+
+
+def test_transformer():
+    num_features = 40
+    num_classes = 87
+    model = Transformer(num_features=num_features, num_classes=num_classes)
+
+    N = 31
+
+    for T in range(7, 30):
+        x = torch.rand(N, T, num_features)
+        y, _, _ = model(x)
+        assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)
@@ -275,15 +275,13 @@ def compute_loss(
     device = graph_compiler.device
     feature = batch["inputs"]
     # at entry, feature is [N, T, C]
-    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
     assert feature.ndim == 3
     feature = feature.to(device)

     supervisions = batch["supervisions"]
     with torch.set_grad_enabled(is_training):
         nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
-        # nnet_output is [N, C, T]
-        nnet_output = nnet_output.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
+        # nnet_output is [N, T, C]

     # NOTE: We need `encode_supervisions` to sort sequences with
     # different duration in decreasing order, required by
@@ -536,6 +534,22 @@ def train_one_epoch(
                 f" best valid loss: {params.best_valid_loss:.4f} "
                 f"best valid epoch: {params.best_valid_epoch}"
             )
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/valid_ctc_loss",
+                    params.valid_ctc_loss,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/valid_att_loss",
+                    params.valid_att_loss,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/valid_loss",
+                    params.valid_loss,
+                    params.batch_idx_train,
+                )

         params.train_loss = tot_loss / tot_frames
@@ -675,5 +689,8 @@ def main():
     run(rank=0, world_size=1, args=args)


+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
 if __name__ == "__main__":
     main()

(File diff suppressed because it is too large.)
@@ -11,11 +11,18 @@ import logging
 import os
 from pathlib import Path

+import torch
 from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor

+# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU
+# and slows things down. Do this outside of main() in case it needs to take
+# effect even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
 def compute_fbank_librispeech():
     src_dir = Path("data/manifests")
@@ -46,8 +53,7 @@ def compute_fbank_librispeech():
             continue
         logging.info(f"Processing {partition}")
         cut_set = CutSet.from_manifests(
-            recordings=m["recordings"],
-            supervisions=m["supervisions"],
+            recordings=m["recordings"], supervisions=m["supervisions"],
         )
         if "train" in partition:
             cut_set = (
@@ -11,11 +11,18 @@ import logging
 import os
 from pathlib import Path

+import torch
 from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer, combine
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor

+# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU
+# and slows things down. Do this outside of main() in case it needs to take
+# effect even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
 def compute_fbank_musan():
     src_dir = Path("data/manifests")
@@ -555,11 +555,14 @@ def rescore_with_attention_decoder(
     model: nn.Module,
     memory: torch.Tensor,
     memory_key_padding_mask: torch.Tensor,
+    sos_id: int,
+    eos_id: int,
 ) -> Dict[str, k2.Fsa]:
     """This function extracts n paths from the given lattice and uses
     an attention decoder to rescore them. The path with the highest
     score is used as the decoding output.

+    Args:
       lattice:
         An FsaVec. It can be the return value of :func:`get_lattice`.
       num_paths:
@@ -573,6 +576,10 @@ def rescore_with_attention_decoder(
         Its shape is `[T, N, C]`.
       memory_key_padding_mask:
         The padding mask for memory with shape [N, T].
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
     Returns:
       A dict of FsaVec, whose key contains a string
       ngram_lm_scale_attention_scale and the value is the
@@ -661,7 +668,11 @@ def rescore_with_attention_decoder(

     # TODO: pass the sos_token_id and eos_token_id via function arguments
     nll = model.decoder_nll(
-        expanded_memory, expanded_memory_key_padding_mask, token_ids, 1, 1
+        memory=expanded_memory,
+        memory_key_padding_mask=expanded_memory_key_padding_mask,
+        token_ids=token_ids,
+        sos_id=sos_id,
+        eos_id=eos_id,
    )
     assert nll.ndim == 2
     assert nll.shape[0] == num_word_seqs
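With this change, callers supply the SOS/EOS IDs explicitly instead of the previously hard-coded 1, 1. A hedged call sketch (lattice, model, memory, memory_key_padding_mask, and lexicon are placeholders from the caller's scope; num_paths=100 is an arbitrary choice, not a value from the commit):

    best_path_dict = rescore_with_attention_decoder(
        lattice=lattice,
        num_paths=100,
        model=model,
        memory=memory,
        memory_key_padding_mask=memory_key_padding_mask,
        sos_id=lexicon.sos_id,
        eos_id=lexicon.eos_id,
    )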