Refactor transformer.py

This commit is contained in:
Fangjun Kuang 2021-08-02 23:48:26 +08:00
parent 1fa30998da
commit f6091b10c0
10 changed files with 689 additions and 400 deletions

View File

@ -89,15 +89,21 @@ class Conformer(Transformer):
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
x: Tensor of dimension (batch_size, num_features, input_length).
supervisions : Supervison in lhotse format, i.e., batch['supervisions']
x:
The model input. Its shape is [N, T, C].
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
CAUTION: It contains length information, i.e., start and number of
frames, before subsampling
It is read directly from the batch, without any sorting. It is used
to compute encoder padding mask, which is used as memory key padding
mask for the decoder.
Returns:
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
Tensor: Mask tensor of dimension (batch_size, input_length)
"""
x = x.permute(0, 2, 1) # (B, F, T) -> (B, T, F)
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
@ -796,8 +802,7 @@ class RelPositionMultiheadAttention(nn.Module):
bsz, num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"),
)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
@ -867,12 +872,7 @@ class ConvolutionModule(nn.Module):
)
self.norm = nn.BatchNorm1d(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
channels, channels, kernel_size=1, stride=1, padding=0, bias=bias,
)
self.activation = Swish()

View File

@ -147,15 +147,10 @@ def decode_one_batch(
feature = feature.to(device)
# at entry, feature is [N, T, C]
feature = feature.permute(0, 2, 1) # now feature is [N, C, T]
supervisions = batch["supervisions"]
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
# nnet_output is [N, C, T]
nnet_output = nnet_output.permute(0, 2, 1)
# now nnet_output is [N, T, C]
# nnet_output is [N, T, C]
supervision_segments = torch.stack(
(
@ -227,6 +222,8 @@ def decode_one_batch(
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=lexicon.sos_id,
eos_id=lexicon.eos_id,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
@ -468,5 +465,8 @@ def main():
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,144 @@
import torch
import torch.nn as nn
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Convert an input of shape [N, T, idim] to an output
with shape [N, T', odim], where
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(self, idim: int, odim: int) -> None:
"""
Args:
idim:
Input dim. The input shape is [N, T, idim].
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
"""
assert idim >= 7
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=1, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(),
nn.Conv2d(
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(),
)
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is [N, T, idim].
Returns:
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
"""
# On entry, x is [N, T, idim]
x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
x = self.conv(x)
# Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape [N, ((T-1)//2 - 1))//2, odim]
return x
class VggSubsampling(nn.Module):
"""Trying to follow the setup described in the following paper:
https://arxiv.org/pdf/1910.09799.pdf
This paper is not 100% explicit so I am guessing to some extent,
and trying to compare with other VGG implementations.
Convert an input of shape [N, T, idim] to an output
with shape [N, T', odim], where
T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
"""
def __init__(self, idim: int, odim: int) -> None:
"""Construct a VggSubsampling object.
This uses 2 VGG blocks with 2 Conv2d layers each,
subsampling its input by a factor of 4 in the time dimensions.
Args:
idim:
Input dim. The input shape is [N, T, idim].
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
"""
super().__init__()
cur_channels = 1
layers = []
block_dims = [32, 64]
# The decision to use padding=1 for the 1st convolution, then padding=0
# for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
# a back-compatibility concern so that the number of frames at the
# output would be equal to:
# (((T-1)//2)-1)//2.
# We can consider changing this by using padding=1 on the
# 2nd convolution, so the num-frames at the output would be T//4.
for block_dim in block_dims:
layers.append(
torch.nn.Conv2d(
in_channels=cur_channels,
out_channels=block_dim,
kernel_size=3,
padding=1,
stride=1,
)
)
layers.append(torch.nn.ReLU())
layers.append(
torch.nn.Conv2d(
in_channels=block_dim,
out_channels=block_dim,
kernel_size=3,
padding=0,
stride=1,
)
)
layers.append(
torch.nn.MaxPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True
)
)
cur_channels = block_dim
self.layers = nn.Sequential(*layers)
self.out = nn.Linear(
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is [N, T, idim].
Returns:
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
"""
x = x.unsqueeze(1)
x = self.layers(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
return x

View File

@ -0,0 +1,33 @@
#!/usr/bin/env python3
from subsampling import Conv2dSubsampling
from subsampling import VggSubsampling
import torch
def test_conv2d_subsampling():
N = 3
odim = 2
for T in range(7, 19):
for idim in range(7, 20):
model = Conv2dSubsampling(idim=idim, odim=odim)
x = torch.empty(N, T, idim)
y = model(x)
assert y.shape[0] == N
assert y.shape[1] == ((T - 1) // 2 - 1) // 2
assert y.shape[2] == odim
def test_vgg_subsampling():
N = 3
odim = 2
for T in range(7, 19):
for idim in range(7, 20):
model = VggSubsampling(idim=idim, odim=odim)
x = torch.empty(N, T, idim)
y = model(x)
assert y.shape[0] == N
assert y.shape[1] == ((T - 1) // 2 - 1) // 2
assert y.shape[2] == odim

View File

@ -0,0 +1,36 @@
#!/usr/bin/env python3
import torch
from transformer import Transformer, encoder_padding_mask
def test_encoder_padding_mask():
supervisions = {
"sequence_idx": torch.tensor([0, 1, 2]),
"start_frame": torch.tensor([0, 0, 0]),
"num_frames": torch.tensor([18, 7, 13]),
}
max_len = ((18 - 1) // 2 - 1) // 2
mask = encoder_padding_mask(max_len, supervisions)
expected_mask = torch.tensor(
[
[False, False, False], # ((18 - 1)//2 - 1)//2 = 3,
[False, True, True], # ((7 - 1)//2 - 1)//2 = 1,
[False, False, True], # ((13 - 1)//2 - 1)//2 = 2,
]
)
assert torch.all(torch.eq(mask, expected_mask))
def test_transformer():
num_features = 40
num_classes = 87
model = Transformer(num_features=num_features, num_classes=num_classes)
N = 31
for T in range(7, 30):
x = torch.rand(N, T, num_features)
y, _, _ = model(x)
assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)

View File

@ -275,15 +275,13 @@ def compute_loss(
device = graph_compiler.device
feature = batch["inputs"]
# at entry, feature is [N, T, C]
feature = feature.permute(0, 2, 1) # now feature is [N, C, T]
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
with torch.set_grad_enabled(is_training):
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is [N, C, T]
nnet_output = nnet_output.permute(0, 2, 1) # [N, C, T] -> [N, T, C]
# nnet_output is [N, T, C]
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
@ -536,6 +534,22 @@ def train_one_epoch(
f" best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/valid_ctc_loss",
params.valid_ctc_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/valid_att_loss",
params.valid_att_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/valid_loss",
params.valid_loss,
params.batch_idx_train,
)
params.train_loss = tot_loss / tot_frames
@ -675,5 +689,8 @@ def main():
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

View File

@ -11,11 +11,18 @@ import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and
# slow things down. Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_librispeech():
src_dir = Path("data/manifests")
@ -46,8 +53,7 @@ def compute_fbank_librispeech():
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
recordings=m["recordings"], supervisions=m["supervisions"],
)
if "train" in partition:
cut_set = (

View File

@ -11,11 +11,18 @@ import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer, combine
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and
# slow things down. Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_musan():
src_dir = Path("data/manifests")

View File

@ -555,11 +555,14 @@ def rescore_with_attention_decoder(
model: nn.Module,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
sos_id: int,
eos_id: int,
) -> Dict[str, k2.Fsa]:
"""This function extracts n paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest
score is used as the decoding output.
Args:
lattice:
An FsaVec. It can be the return value of :func:`get_lattice`.
num_paths:
@ -573,6 +576,10 @@ def rescore_with_attention_decoder(
Its shape is `[T, N, C]`.
memory_key_padding_mask:
The padding mask for memory with shape [N, T].
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
Returns:
A dict of FsaVec, whose key contains a string
ngram_lm_scale_attention_scale and the value is the
@ -661,7 +668,11 @@ def rescore_with_attention_decoder(
# TODO: pass the sos_token_id and eos_token_id via function arguments
nll = model.decoder_nll(
expanded_memory, expanded_memory_key_padding_mask, token_ids, 1, 1
memory=expanded_memory,
memory_key_padding_mask=expanded_memory_key_padding_mask,
token_ids=token_ids,
sos_id=sos_id,
eos_id=eos_id,
)
assert nll.ndim == 2
assert nll.shape[0] == num_word_seqs