Remove unused code.

Fangjun Kuang 2021-08-03 17:24:06 +08:00
parent f6091b10c0
commit 2be7a0a555
4 changed files with 166 additions and 227 deletions

View File

@@ -84,7 +84,7 @@ class Conformer(Transformer):
# and throws an error without this change.
self.after_norm = identity
def encode(
def run_encoder(
self, x: Tensor, supervisions: Optional[Supervisions] = None
) -> Tuple[Tensor, Optional[Tensor]]:
"""
@@ -802,7 +802,8 @@ class RelPositionMultiheadAttention(nn.Module):
bsz, num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"),
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
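For reference, a minimal standalone sketch (shapes and values are made up for illustration) of how a (batch, src_len) key padding mask is broadcast over the head and query dimensions by the unsqueeze calls above, using only plain PyTorch:

import torch

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 7
attn_output_weights = torch.rand(bsz, num_heads, tgt_len, src_len)
# True marks padded (invalid) key positions of each utterance.
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
key_padding_mask[0, 5:] = True
# (bsz, src_len) -> (bsz, 1, 1, src_len), broadcast over heads and queries.
masked = attn_output_weights.masked_fill(
    key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
)
assert torch.isinf(masked[0, :, :, 5:]).all()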
@@ -872,7 +873,12 @@ class ConvolutionModule(nn.Module):
)
self.norm = nn.BatchNorm1d(channels)
self.pointwise_conv2 = nn.Conv1d(
channels, channels, kernel_size=1, stride=1, padding=0, bias=bias,
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.activation = Swish()
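A pointwise (kernel_size=1) Conv1d such as pointwise_conv2 is just a per-frame linear projection across channels. A small sketch checking this equivalence with plain PyTorch (sizes made up):

import torch
import torch.nn as nn

channels, N, T = 8, 2, 10
conv = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)
linear = nn.Linear(channels, channels)
# Copy the conv weights into the linear layer: (C_out, C_in, 1) -> (C_out, C_in).
linear.weight.data.copy_(conv.weight.data.squeeze(-1))
linear.bias.data.copy_(conv.bias.data)

x = torch.rand(N, channels, T)  # Conv1d expects (N, C, T)
y_conv = conv(x)
y_linear = linear(x.permute(0, 2, 1)).permute(0, 2, 1)
assert torch.allclose(y_conv, y_linear, atol=1e-6)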

View File

@@ -15,6 +15,7 @@ import torch
import torch.nn as nn
from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.decode import (
@@ -85,7 +86,7 @@ def get_params() -> AttributeDict:
# - whole-lattice-rescoring
# - attention-decoder
# "method": "whole-lattice-rescoring",
"method": "1best",
"method": "attention-decoder",
# num_paths is used when method is "nbest", "nbest-rescoring",
# and attention-decoder
"num_paths": 100,
@@ -100,6 +101,8 @@ def decode_one_batch(
HLG: k2.Fsa,
batch: dict,
lexicon: Lexicon,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[int]]]:
"""Decode one batch and return the result in a dict. The dict has the
@@ -133,6 +136,10 @@
for the format of the `batch`.
lexicon:
It contains word symbol table.
sos_id:
The token ID of the SOS.
eos_id:
The token ID of the EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
@@ -222,8 +229,8 @@
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=lexicon.sos_id,
eos_id=lexicon.eos_id,
sos_id=sos_id,
eos_id=eos_id,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
@@ -242,6 +249,8 @@ def decode_dataset(
model: nn.Module,
HLG: k2.Fsa,
lexicon: Lexicon,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[int], List[int]]]]:
"""Decode dataset.
@@ -257,6 +266,10 @@
The decoding graph.
lexicon:
It contains word symbol table.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
@@ -284,6 +297,8 @@
batch=batch,
lexicon=lexicon,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)
for lm_scale, hyps in hyps_dict.items():
@@ -364,6 +379,15 @@ def main():
logging.info(f"device: {device}")
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
HLG = HLG.to(device)
assert HLG.requires_grad is False
@@ -456,6 +480,8 @@ def main():
HLG=HLG,
lexicon=lexicon,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)
save_results(

View File

@@ -1,7 +1,16 @@
#!/usr/bin/env python3
import torch
from transformer import Transformer, encoder_padding_mask
from transformer import (
Transformer,
encoder_padding_mask,
generate_square_subsequent_mask,
decoder_padding_mask,
add_sos,
add_eos,
)
from torch.nn.utils.rnn import pad_sequence
def test_encoder_padding_mask():
@@ -34,3 +43,47 @@ def test_transformer():
x = torch.rand(N, T, num_features)
y, _, _ = model(x)
assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)
def test_generate_square_subsequent_mask():
s = 5
mask = generate_square_subsequent_mask(s)
inf = float("inf")
expected_mask = torch.tensor(
[
[0.0, -inf, -inf, -inf, -inf],
[0.0, 0.0, -inf, -inf, -inf],
[0.0, 0.0, 0.0, -inf, -inf],
[0.0, 0.0, 0.0, 0.0, -inf],
[0.0, 0.0, 0.0, 0.0, 0.0],
]
)
assert torch.all(torch.eq(mask, expected_mask))
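One common way to build such a causal mask in plain PyTorch (a sketch only; the actual generate_square_subsequent_mask in transformer.py may be implemented differently) is with torch.triu:

import torch

def square_subsequent_mask(sz: int) -> torch.Tensor:
    # -inf strictly above the main diagonal, 0.0 on and below it.
    return torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)

inf = float("inf")
assert torch.equal(
    square_subsequent_mask(3),
    torch.tensor([[0.0, -inf, -inf], [0.0, 0.0, -inf], [0.0, 0.0, 0.0]]),
)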
def test_decoder_padding_mask():
x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])]
y = pad_sequence(x, batch_first=True, padding_value=-1)
mask = decoder_padding_mask(y, ignore_id=-1)
expected_mask = torch.tensor(
[
[False, False, True],
[False, True, True],
[False, False, False],
]
)
assert torch.all(torch.eq(mask, expected_mask))
def test_add_sos():
x = [[1, 2], [3], [2, 5, 8]]
y = add_sos(x, sos_id=0)
expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]]
assert y == expected_y
def test_add_eos():
x = [[1, 2], [3], [2, 5, 8]]
y = add_eos(x, eos_id=0)
expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]]
assert y == expected_y

View File

@@ -10,6 +10,7 @@ import torch.nn as nn
from subsampling import Conv2dSubsampling, VggSubsampling
from icefall.utils import get_texts
from torch.nn.utils.rnn import pad_sequence
# Note: TorchScript requires Dict/List/etc. to be fully typed.
Supervisions = Dict[str, torch.Tensor]
@@ -177,14 +178,17 @@ class Transformer(nn.Module):
x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T]
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C]
encoder_memory, memory_key_padding_mask = self.encode(x, supervision)
x = self.encoder_output(encoder_memory)
encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision
)
x = self.ctc_output(encoder_memory)
return x, encoder_memory, memory_key_padding_mask
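The permute calls around feat_batchnorm are needed because nn.BatchNorm1d normalizes over dimension 1, i.e. it expects a (N, C, T) layout while the features arrive as (N, T, C). A tiny standalone sketch (sizes made up):

import torch
import torch.nn as nn

N, T, C = 2, 20, 80
feat_batchnorm = nn.BatchNorm1d(C)

x = torch.rand(N, T, C)     # (N, T, C) features
x = x.permute(0, 2, 1)      # -> (N, C, T) for BatchNorm1d
x = feat_batchnorm(x)
x = x.permute(0, 2, 1)      # -> back to (N, T, C)
assert x.shape == (N, T, C)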
def encode(
def run_encoder(
self, x: torch.Tensor, supervisions: Optional[Supervisions] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
"""Run the transformer encoder.
Args:
x:
The model input. Its shape is [N, T, C].
@@ -194,8 +198,8 @@ class Transformer(nn.Module):
CAUTION: It contains length information, i.e., start and number of
frames, before subsampling
It is read directly from the batch, without any sorting. It is used
to compute encoder padding mask, which is used as memory key padding
mask for the decoder.
to compute the encoder padding mask, which is used as memory key
padding mask for the decoder.
Returns:
Return a tuple with two tensors:
- The encoder output, with shape [T, N, C]
@@ -212,7 +216,7 @@ class Transformer(nn.Module):
return x, mask
def encoder_output(self, x: torch.Tensor) -> torch.Tensor:
def ctc_output(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
@@ -232,46 +236,16 @@ class Transformer(nn.Module):
self,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
supervision: Optional[Supervisions] = None,
L_inv: Optional[k2.Fsa] = None,
word_table: Optional[k2.SymbolTable] = None,
oov_str: Optional[str] = None,
token_ids: List[List[int]] = None,
sos_id: Optional[int] = None,
eos_id: Optional[int] = None,
token_ids: List[List[int]],
sos_id: int,
eos_id: int,
) -> torch.Tensor:
"""
Note:
If phone based lexicon is used, the following arguments are required:
- supervision
- L_inv
- word_table
- oov_str
If BPE based lexicon is used, the following arguments are required:
- token_ids
- sos_id
- eos_id
Args:
memory:
It's the output of the encoder with shape [T, N, C]
memory_key_padding_mask:
The padding mask from the encoder.
supervision:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
(CAUTION: It contains length information, i.e., start and number of
frames, before subsampling)
L_inv:
It is an FSA with labels being word IDs and aux_labels being
token IDs (e.g., phone IDs or word piece IDs).
word_table:
Word table providing mapping between words and IDs.
oov_str:
The OOV word, e.g., '<UNK>'
token_ids:
A list-of-list IDs. Each sublist contains IDs for an utterance.
The IDs can be either phone IDs or word piece IDs.
@@ -284,29 +258,13 @@ class Transformer(nn.Module):
A scalar, the **sum** of label smoothing loss over utterances
in the batch without any normalization.
"""
if supervision is not None and word_table is not None:
batch_text = get_normal_transcripts(
supervision, word_table, oov_str
)
ys_in_pad, ys_out_pad = add_sos_eos(
batch_text,
L_inv,
sos_id,
eos_id,
)
elif token_ids is not None:
_sos = torch.tensor([sos_id])
_eos = torch.tensor([eos_id])
ys_in = [
torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids
]
ys_out = [
torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids
]
ys_in_pad = pad_list(ys_in, eos_id)
ys_out_pad = pad_list(ys_out, -1)
else:
raise ValueError("Invalid input for decoder self attention")
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
device = memory.device
ys_in_pad = ys_in_pad.to(device)
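To make the new target preparation concrete, here is a toy example (a sketch with made-up token IDs; add_sos/add_eos are inlined rather than imported so it runs standalone). padding_value=-1 for the targets matches the ignore value used elsewhere in this file (e.g. decoder_padding_mask's default ignore_id):

import torch
from torch.nn.utils.rnn import pad_sequence

token_ids = [[1, 2], [3], [2, 5, 8]]
sos_id, eos_id = 0, 0

ys_in = [torch.tensor([sos_id] + y) for y in token_ids]   # add_sos
ys_out = [torch.tensor(y + [eos_id]) for y in token_ids]  # add_eos

# Decoder inputs are padded with eos_id, decoder targets with -1.
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
# ys_in_pad:  [[0, 1, 2, 0], [0, 3, 0, 0], [0, 2, 5, 8]]
# ys_out_pad: [[1, 2, 0, -1], [3, 0, -1, -1], [2, 5, 8, 0]]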
@@ -316,6 +274,8 @@ class Transformer(nn.Module):
device
)
# TODO: Use eos_id as ignore_id.
# tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)
tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C)
@@ -362,19 +322,14 @@ class Transformer(nn.Module):
"""
# The common part between this function and decoder_forward could be
# extracted as a separate function.
if token_ids is not None:
_sos = torch.tensor([sos_id])
_eos = torch.tensor([eos_id])
ys_in = [
torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids
]
ys_out = [
torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids
]
ys_in_pad = pad_list(ys_in, eos_id)
ys_out_pad = pad_list(ys_out, -1)
else:
raise ValueError("Invalid input for decoder self attention")
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
device = memory.device
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
@@ -384,6 +339,8 @@ class Transformer(nn.Module):
device
)
# TODO: Use eos_id as ignore_id.
# tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)
tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F)
@@ -948,8 +905,8 @@ def decoder_padding_mask(
) -> torch.Tensor:
"""Generate a length mask for input.
The masked position are filled with bool(True),
Unmasked positions are filled with bool(False).
The masked positions are filled with True,
unmasked positions are filled with False.
Args:
ys_pad:
@@ -965,45 +922,16 @@ def decoder_padding_mask(
return ys_mask
def get_normal_transcripts(
supervision: Supervisions, words: k2.SymbolTable, oov: str = "<UNK>"
) -> List[List[int]]:
"""Get normal transcripts (1 input recording has 1 transcript)
from lhotse cut format.
Achieved by concatenating the transcripts corresponding to the
same recording.
Args:
supervision:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
words:
The word symbol table.
oov:
Out of vocabulary word.
Returns:
List[List[int]]: List of concatenated transcripts, length is batch_size
"""
texts = [
[token if token in words else oov for token in text.split(" ")]
for text in supervision["text"]
]
texts_ids = [[words[token] for token in text] for text in texts]
batch_text = [
[] for _ in range(int(supervision["sequence_idx"].max().item()) + 1)
]
for sequence_idx, text in zip(supervision["sequence_idx"], texts_ids):
batch_text[sequence_idx] = batch_text[sequence_idx] + text
return batch_text
def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
"""Generate a square mask for the sequence. The masked positions are
filled with float('-inf'). Unmasked positions are filled with float(0.0).
The mask can be used for masked self-attention.
For instance, if sz is 3, it returns::
tensor([[0., -inf, -inf],
[0., 0., -inf],
[0., 0., 0]])
Args:
sz: mask size
@@ -1020,115 +948,41 @@ def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
return mask
def add_sos_eos(
ys: List[List[int]],
L_inv: k2.Fsa,
sos_id: int,
eos_id: int,
ignore_id: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Add <sos> and <eos> labels.
def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
"""Prepend sos_id to each utterance.
Args:
ys:
Batch of unpadded target sequences (i.e., word IDs)
L_inv:
Its labels are words, while its aux_labels are tokens.
token_ids:
A list-of-list of token IDs. Each sublist contains
token IDs (e.g., word piece IDs) of an utterance.
sos_id:
index of <sos>
The ID of the SOS token.
Return:
Return a new list-of-list, where each sublist starts
with SOS ID.
"""
ans = []
for utt in token_ids:
ans.append([sos_id] + utt)
return ans
def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
"""Append eos_id to each utterance.
Args:
token_ids:
A list-of-list of token IDs. Each sublist contains
token IDs (e.g., word piece IDs) of an utterance.
eos_id:
index of <eos>
ignore_id:
value for padding
The ID of the EOS token.
Returns:
Return a tuple containing two tensors:
- Input of transformer decoder.
Padded tensor of dimension (batch_size, max_length).
- Output of transformer decoder.
Padded tensor of dimension (batch_size, max_length).
Return:
Return a new list-of-list, where each sublist ends
with EOS ID.
"""
_sos = torch.tensor([sos_id])
_eos = torch.tensor([eos_id])
ys = get_hierarchical_targets(ys, L_inv)
ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
return pad_list(ys_in, eos_id), pad_list(ys_out, ignore_id)
def pad_list(ys: List[torch.Tensor], pad_value: float) -> torch.Tensor:
"""Perform padding for the list of tensors.
Args:
ys:
List of tensors. len(ys) = batch_size.
pad_value:
Value for padding.
Returns:
Tensor: Padded tensor (batch_size, max_length, `*`).
Examples:
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
>>> x
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
>>> pad_list(x, 0)
tensor([[1., 1., 1., 1.],
[1., 1., 0., 0.],
[1., 0., 0., 0.]])
"""
n_batch = len(ys)
max_len = max(x.size(0) for x in ys)
pad = ys[0].new_full((n_batch, max_len, *ys[0].size()[1:]), pad_value)
for i in range(n_batch):
pad[i, : ys[i].size(0)] = ys[i]
return pad
def get_hierarchical_targets(
ys: List[List[int]], L_inv: Optional[k2.Fsa] = None
) -> List[torch.Tensor]:
"""Get hierarchical transcripts (i.e., phone level transcripts) from
transcripts (i.e., word level transcripts).
Args:
ys:
Word level transcripts. Each sublist is a transcript of an utterance.
L_inv:
Its labels are words, while its aux_labels are tokens.
Returns:
List[torch.Tensor]:
Token level transcripts.
"""
if L_inv is None:
return [torch.tensor(y) for y in ys]
device = L_inv.device
transcripts = k2.create_fsa_vec(
[k2.linear_fsa(x, device=device) for x in ys]
)
transcripts_with_self_loops = k2.add_epsilon_self_loops(transcripts)
transcripts_lexicon = k2.intersect(
L_inv, transcripts_with_self_loops, treat_epsilons_specially=False
)
# Don't call invert_() above because we want to return phone IDs,
# which is the `aux_labels` of transcripts_lexicon
transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
transcripts_lexicon = k2.top_sort(transcripts_lexicon)
transcripts_lexicon = k2.shortest_path(
transcripts_lexicon, use_double_scores=True
)
ys = get_texts(transcripts_lexicon)
ys = [torch.tensor(y) for y in ys]
return ys
ans = []
for utt in token_ids:
ans.append(utt + [eos_id])
return ans