mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Refactor transformer.py
This commit is contained in:
parent
1fa30998da
commit
f6091b10c0
@ -89,15 +89,21 @@ class Conformer(Transformer):
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
x: Tensor of dimension (batch_size, num_features, input_length).
|
||||
supervisions : Supervison in lhotse format, i.e., batch['supervisions']
|
||||
x:
|
||||
The model input. Its shape is [N, T, C].
|
||||
supervisions:
|
||||
Supervision in lhotse format.
|
||||
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
|
||||
CAUTION: It contains length information, i.e., start and number of
|
||||
frames, before subsampling
|
||||
It is read directly from the batch, without any sorting. It is used
|
||||
to compute encoder padding mask, which is used as memory key padding
|
||||
mask for the decoder.
|
||||
|
||||
Returns:
|
||||
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
|
||||
Tensor: Mask tensor of dimension (batch_size, input_length)
|
||||
"""
|
||||
x = x.permute(0, 2, 1) # (B, F, T) -> (B, T, F)
|
||||
|
||||
x = self.encoder_embed(x)
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
||||
@ -796,8 +802,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
bsz, num_heads, tgt_len, src_len
|
||||
)
|
||||
attn_output_weights = attn_output_weights.masked_fill(
|
||||
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
||||
float("-inf"),
|
||||
key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"),
|
||||
)
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz * num_heads, tgt_len, src_len
|
||||
@ -867,12 +872,7 @@ class ConvolutionModule(nn.Module):
|
||||
)
|
||||
self.norm = nn.BatchNorm1d(channels)
|
||||
self.pointwise_conv2 = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
channels, channels, kernel_size=1, stride=1, padding=0, bias=bias,
|
||||
)
|
||||
self.activation = Swish()
|
||||
|
||||
|
@ -147,15 +147,10 @@ def decode_one_batch(
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is [N, T, C]
|
||||
|
||||
feature = feature.permute(0, 2, 1) # now feature is [N, C, T]
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
|
||||
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
|
||||
# nnet_output is [N, C, T]
|
||||
|
||||
nnet_output = nnet_output.permute(0, 2, 1)
|
||||
# now nnet_output is [N, T, C]
|
||||
# nnet_output is [N, T, C]
|
||||
|
||||
supervision_segments = torch.stack(
|
||||
(
|
||||
@ -227,6 +222,8 @@ def decode_one_batch(
|
||||
model=model,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
sos_id=lexicon.sos_id,
|
||||
eos_id=lexicon.eos_id,
|
||||
)
|
||||
else:
|
||||
assert False, f"Unsupported decoding method: {params.method}"
|
||||
@ -468,5 +465,8 @@ def main():
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
144
egs/librispeech/ASR/conformer_ctc/subsampling.py
Normal file
144
egs/librispeech/ASR/conformer_ctc/subsampling.py
Normal file
@ -0,0 +1,144 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Conv2dSubsampling(nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/4 length).
|
||||
|
||||
Convert an input of shape [N, T, idim] to an output
|
||||
with shape [N, T', odim], where
|
||||
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
|
||||
|
||||
It is based on
|
||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
|
||||
"""
|
||||
|
||||
def __init__(self, idim: int, odim: int) -> None:
|
||||
"""
|
||||
Args:
|
||||
idim:
|
||||
Input dim. The input shape is [N, T, idim].
|
||||
Caution: It requires: T >=7, idim >=7
|
||||
odim:
|
||||
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
|
||||
"""
|
||||
assert idim >= 7
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=1, out_channels=odim, kernel_size=3, stride=2
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
|
||||
),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x:
|
||||
Its shape is [N, T, idim].
|
||||
|
||||
Returns:
|
||||
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
|
||||
"""
|
||||
# On entry, x is [N, T, idim]
|
||||
x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
|
||||
x = self.conv(x)
|
||||
# Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
# Now x is of shape [N, ((T-1)//2 - 1))//2, odim]
|
||||
return x
|
||||
|
||||
|
||||
class VggSubsampling(nn.Module):
|
||||
"""Trying to follow the setup described in the following paper:
|
||||
https://arxiv.org/pdf/1910.09799.pdf
|
||||
|
||||
This paper is not 100% explicit so I am guessing to some extent,
|
||||
and trying to compare with other VGG implementations.
|
||||
|
||||
Convert an input of shape [N, T, idim] to an output
|
||||
with shape [N, T', odim], where
|
||||
T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
|
||||
"""
|
||||
|
||||
def __init__(self, idim: int, odim: int) -> None:
|
||||
"""Construct a VggSubsampling object.
|
||||
|
||||
This uses 2 VGG blocks with 2 Conv2d layers each,
|
||||
subsampling its input by a factor of 4 in the time dimensions.
|
||||
|
||||
Args:
|
||||
idim:
|
||||
Input dim. The input shape is [N, T, idim].
|
||||
Caution: It requires: T >=7, idim >=7
|
||||
odim:
|
||||
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
cur_channels = 1
|
||||
layers = []
|
||||
block_dims = [32, 64]
|
||||
|
||||
# The decision to use padding=1 for the 1st convolution, then padding=0
|
||||
# for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
|
||||
# a back-compatibility concern so that the number of frames at the
|
||||
# output would be equal to:
|
||||
# (((T-1)//2)-1)//2.
|
||||
# We can consider changing this by using padding=1 on the
|
||||
# 2nd convolution, so the num-frames at the output would be T//4.
|
||||
for block_dim in block_dims:
|
||||
layers.append(
|
||||
torch.nn.Conv2d(
|
||||
in_channels=cur_channels,
|
||||
out_channels=block_dim,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=1,
|
||||
)
|
||||
)
|
||||
layers.append(torch.nn.ReLU())
|
||||
layers.append(
|
||||
torch.nn.Conv2d(
|
||||
in_channels=block_dim,
|
||||
out_channels=block_dim,
|
||||
kernel_size=3,
|
||||
padding=0,
|
||||
stride=1,
|
||||
)
|
||||
)
|
||||
layers.append(
|
||||
torch.nn.MaxPool2d(
|
||||
kernel_size=2, stride=2, padding=0, ceil_mode=True
|
||||
)
|
||||
)
|
||||
cur_channels = block_dim
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
self.out = nn.Linear(
|
||||
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x:
|
||||
Its shape is [N, T, idim].
|
||||
|
||||
Returns:
|
||||
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
|
||||
"""
|
||||
x = x.unsqueeze(1)
|
||||
x = self.layers(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
return x
|
33
egs/librispeech/ASR/conformer_ctc/test_subsampling.py
Executable file
33
egs/librispeech/ASR/conformer_ctc/test_subsampling.py
Executable file
@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from subsampling import Conv2dSubsampling
|
||||
from subsampling import VggSubsampling
|
||||
import torch
|
||||
|
||||
|
||||
def test_conv2d_subsampling():
|
||||
N = 3
|
||||
odim = 2
|
||||
|
||||
for T in range(7, 19):
|
||||
for idim in range(7, 20):
|
||||
model = Conv2dSubsampling(idim=idim, odim=odim)
|
||||
x = torch.empty(N, T, idim)
|
||||
y = model(x)
|
||||
assert y.shape[0] == N
|
||||
assert y.shape[1] == ((T - 1) // 2 - 1) // 2
|
||||
assert y.shape[2] == odim
|
||||
|
||||
|
||||
def test_vgg_subsampling():
|
||||
N = 3
|
||||
odim = 2
|
||||
|
||||
for T in range(7, 19):
|
||||
for idim in range(7, 20):
|
||||
model = VggSubsampling(idim=idim, odim=odim)
|
||||
x = torch.empty(N, T, idim)
|
||||
y = model(x)
|
||||
assert y.shape[0] == N
|
||||
assert y.shape[1] == ((T - 1) // 2 - 1) // 2
|
||||
assert y.shape[2] == odim
|
36
egs/librispeech/ASR/conformer_ctc/test_transformer.py
Normal file
36
egs/librispeech/ASR/conformer_ctc/test_transformer.py
Normal file
@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
from transformer import Transformer, encoder_padding_mask
|
||||
|
||||
|
||||
def test_encoder_padding_mask():
|
||||
supervisions = {
|
||||
"sequence_idx": torch.tensor([0, 1, 2]),
|
||||
"start_frame": torch.tensor([0, 0, 0]),
|
||||
"num_frames": torch.tensor([18, 7, 13]),
|
||||
}
|
||||
|
||||
max_len = ((18 - 1) // 2 - 1) // 2
|
||||
mask = encoder_padding_mask(max_len, supervisions)
|
||||
expected_mask = torch.tensor(
|
||||
[
|
||||
[False, False, False], # ((18 - 1)//2 - 1)//2 = 3,
|
||||
[False, True, True], # ((7 - 1)//2 - 1)//2 = 1,
|
||||
[False, False, True], # ((13 - 1)//2 - 1)//2 = 2,
|
||||
]
|
||||
)
|
||||
assert torch.all(torch.eq(mask, expected_mask))
|
||||
|
||||
|
||||
def test_transformer():
|
||||
num_features = 40
|
||||
num_classes = 87
|
||||
model = Transformer(num_features=num_features, num_classes=num_classes)
|
||||
|
||||
N = 31
|
||||
|
||||
for T in range(7, 30):
|
||||
x = torch.rand(N, T, num_features)
|
||||
y, _, _ = model(x)
|
||||
assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)
|
@ -275,15 +275,13 @@ def compute_loss(
|
||||
device = graph_compiler.device
|
||||
feature = batch["inputs"]
|
||||
# at entry, feature is [N, T, C]
|
||||
feature = feature.permute(0, 2, 1) # now feature is [N, C, T]
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
with torch.set_grad_enabled(is_training):
|
||||
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
|
||||
# nnet_output is [N, C, T]
|
||||
nnet_output = nnet_output.permute(0, 2, 1) # [N, C, T] -> [N, T, C]
|
||||
# nnet_output is [N, T, C]
|
||||
|
||||
# NOTE: We need `encode_supervisions` to sort sequences with
|
||||
# different duration in decreasing order, required by
|
||||
@ -536,6 +534,22 @@ def train_one_epoch(
|
||||
f" best valid loss: {params.best_valid_loss:.4f} "
|
||||
f"best valid epoch: {params.best_valid_epoch}"
|
||||
)
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/valid_ctc_loss",
|
||||
params.valid_ctc_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/valid_att_loss",
|
||||
params.valid_att_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/valid_loss",
|
||||
params.valid_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
|
||||
params.train_loss = tot_loss / tot_frames
|
||||
|
||||
@ -675,5 +689,8 @@ def main():
|
||||
run(rank=0, world_size=1, args=args)
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -11,11 +11,18 @@ import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and
|
||||
# slow things down. Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_librispeech():
|
||||
src_dir = Path("data/manifests")
|
||||
@ -46,8 +53,7 @@ def compute_fbank_librispeech():
|
||||
continue
|
||||
logging.info(f"Processing {partition}")
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
recordings=m["recordings"], supervisions=m["supervisions"],
|
||||
)
|
||||
if "train" in partition:
|
||||
cut_set = (
|
||||
|
@ -11,11 +11,18 @@ import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer, combine
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and
|
||||
# slow things down. Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_musan():
|
||||
src_dir = Path("data/manifests")
|
||||
|
@ -555,11 +555,14 @@ def rescore_with_attention_decoder(
|
||||
model: nn.Module,
|
||||
memory: torch.Tensor,
|
||||
memory_key_padding_mask: torch.Tensor,
|
||||
sos_id: int,
|
||||
eos_id: int,
|
||||
) -> Dict[str, k2.Fsa]:
|
||||
"""This function extracts n paths from the given lattice and uses
|
||||
an attention decoder to rescore them. The path with the highest
|
||||
score is used as the decoding output.
|
||||
|
||||
Args:
|
||||
lattice:
|
||||
An FsaVec. It can be the return value of :func:`get_lattice`.
|
||||
num_paths:
|
||||
@ -573,6 +576,10 @@ def rescore_with_attention_decoder(
|
||||
Its shape is `[T, N, C]`.
|
||||
memory_key_padding_mask:
|
||||
The padding mask for memory with shape [N, T].
|
||||
sos_id:
|
||||
The token ID for SOS.
|
||||
eos_id:
|
||||
The token ID for EOS.
|
||||
Returns:
|
||||
A dict of FsaVec, whose key contains a string
|
||||
ngram_lm_scale_attention_scale and the value is the
|
||||
@ -661,7 +668,11 @@ def rescore_with_attention_decoder(
|
||||
|
||||
# TODO: pass the sos_token_id and eos_token_id via function arguments
|
||||
nll = model.decoder_nll(
|
||||
expanded_memory, expanded_memory_key_padding_mask, token_ids, 1, 1
|
||||
memory=expanded_memory,
|
||||
memory_key_padding_mask=expanded_memory_key_padding_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
)
|
||||
assert nll.ndim == 2
|
||||
assert nll.shape[0] == num_word_seqs
|
||||
|
Loading…
x
Reference in New Issue
Block a user