Remove unused code.

Fangjun Kuang 2021-08-03 17:24:06 +08:00
parent f6091b10c0
commit 2be7a0a555
4 changed files with 166 additions and 227 deletions

View File

@@ -84,7 +84,7 @@ class Conformer(Transformer):
             # and throws an error without this change.
             self.after_norm = identity

-    def encode(
+    def run_encoder(
         self, x: Tensor, supervisions: Optional[Supervisions] = None
     ) -> Tuple[Tensor, Optional[Tensor]]:
         """
@@ -802,7 +802,8 @@ class RelPositionMultiheadAttention(nn.Module):
                 bsz, num_heads, tgt_len, src_len
             )
             attn_output_weights = attn_output_weights.masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"),
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                float("-inf"),
             )
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
@@ -872,7 +873,12 @@ class ConvolutionModule(nn.Module):
         )
         self.norm = nn.BatchNorm1d(channels)
         self.pointwise_conv2 = nn.Conv1d(
-            channels, channels, kernel_size=1, stride=1, padding=0, bias=bias,
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
         )
         self.activation = Swish()
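
For reference, a minimal self-contained sketch (not part of this commit) of what the reformatted masked_fill call computes: the key padding mask of shape (batch, src_len) is broadcast over heads and target positions, and masked source frames become -inf so they receive zero attention weight after softmax.

import torch

bsz, num_heads, tgt_len, src_len = 2, 4, 3, 5
attn_output_weights = torch.zeros(bsz, num_heads, tgt_len, src_len)
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
key_padding_mask[0, 3:] = True  # last two source frames of utterance 0 are padding

attn_output_weights = attn_output_weights.masked_fill(
    key_padding_mask.unsqueeze(1).unsqueeze(2),
    float("-inf"),
)
# After softmax, the padded positions of utterance 0 get zero weight.
print(attn_output_weights.softmax(dim=-1)[0, 0, 0])
# tensor([0.3333, 0.3333, 0.3333, 0.0000, 0.0000])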

View File

@@ -15,6 +15,7 @@ import torch
 import torch.nn as nn
 from conformer import Conformer

+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.decode import (
@@ -85,7 +86,7 @@ def get_params() -> AttributeDict:
             # - whole-lattice-rescoring
             # - attention-decoder
             # "method": "whole-lattice-rescoring",
-            "method": "1best",
+            "method": "attention-decoder",
             # num_paths is used when method is "nbest", "nbest-rescoring",
             # and attention-decoder
             "num_paths": 100,
@@ -100,6 +101,8 @@ def decode_one_batch(
     HLG: k2.Fsa,
     batch: dict,
     lexicon: Lexicon,
+    sos_id: int,
+    eos_id: int,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[int]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -133,6 +136,10 @@ def decode_one_batch(
         for the format of the `batch`.
       lexicon:
         It contains word symbol table.
+      sos_id:
+        The token ID of the SOS.
+      eos_id:
+        The token ID of the EOS.
       G:
         An LM. It is not None when params.method is "nbest-rescoring"
         or "whole-lattice-rescoring". In general, the G in HLG
@@ -222,8 +229,8 @@ def decode_one_batch(
             model=model,
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
-            sos_id=lexicon.sos_id,
-            eos_id=lexicon.eos_id,
+            sos_id=sos_id,
+            eos_id=eos_id,
         )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
@@ -242,6 +249,8 @@ def decode_dataset(
     model: nn.Module,
     HLG: k2.Fsa,
     lexicon: Lexicon,
+    sos_id: int,
+    eos_id: int,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[int], List[int]]]]:
     """Decode dataset.
@@ -257,6 +266,10 @@ def decode_dataset(
         The decoding graph.
       lexicon:
         It contains word symbol table.
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
       G:
         An LM. It is not None when params.method is "nbest-rescoring"
         or "whole-lattice-rescoring". In general, the G in HLG
@@ -284,6 +297,8 @@ def decode_dataset(
             batch=batch,
             lexicon=lexicon,
             G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
         )

         for lm_scale, hyps in hyps_dict.items():
@@ -364,6 +379,15 @@ def main():

     logging.info(f"device: {device}")

+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+    sos_id = graph_compiler.sos_id
+    eos_id = graph_compiler.eos_id
+
     HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
@@ -456,6 +480,8 @@ def main():
             HLG=HLG,
             lexicon=lexicon,
             G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
         )

         save_results(
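
Side note (an assumption about the surrounding setup, not shown in the diff): BPE lang dirs in icefall typically use one shared "<sos/eos>" symbol, which is why sos_token and eos_token above are the same string and the two IDs coincide. A tiny sketch with a hypothetical token table:

# Hypothetical token table for illustration only; the real IDs come from
# BpeCtcTrainingGraphCompiler / the lang_dir's token list.
token2id = {"<blk>": 0, "<sos/eos>": 1, "<unk>": 2}

sos_id = token2id["<sos/eos>"]
eos_id = token2id["<sos/eos>"]
assert sos_id == eos_id  # a single token serves as both SOS and EOS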

View File

@@ -1,7 +1,16 @@
 #!/usr/bin/env python3

 import torch
-from transformer import Transformer, encoder_padding_mask
+from transformer import (
+    Transformer,
+    encoder_padding_mask,
+    generate_square_subsequent_mask,
+    decoder_padding_mask,
+    add_sos,
+    add_eos,
+)
+from torch.nn.utils.rnn import pad_sequence


 def test_encoder_padding_mask():
@@ -34,3 +43,47 @@ def test_transformer():
     x = torch.rand(N, T, num_features)
     y, _, _ = model(x)
     assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)
+
+
+def test_generate_square_subsequent_mask():
+    s = 5
+    mask = generate_square_subsequent_mask(s)
+    inf = float("inf")
+    expected_mask = torch.tensor(
+        [
+            [0.0, -inf, -inf, -inf, -inf],
+            [0.0, 0.0, -inf, -inf, -inf],
+            [0.0, 0.0, 0.0, -inf, -inf],
+            [0.0, 0.0, 0.0, 0.0, -inf],
+            [0.0, 0.0, 0.0, 0.0, 0.0],
+        ]
+    )
+    assert torch.all(torch.eq(mask, expected_mask))
+
+
+def test_decoder_padding_mask():
+    x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])]
+    y = pad_sequence(x, batch_first=True, padding_value=-1)
+    mask = decoder_padding_mask(y, ignore_id=-1)
+    expected_mask = torch.tensor(
+        [
+            [False, False, True],
+            [False, True, True],
+            [False, False, False],
+        ]
+    )
+    assert torch.all(torch.eq(mask, expected_mask))
+
+
+def test_add_sos():
+    x = [[1, 2], [3], [2, 5, 8]]
+    y = add_sos(x, sos_id=0)
+    expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]]
+    assert y == expected_y
+
+
+def test_add_eos():
+    x = [[1, 2], [3], [2, 5, 8]]
+    y = add_eos(x, eos_id=0)
+    expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]]
+    assert y == expected_y

View File

@@ -10,6 +10,7 @@ import torch.nn as nn

 from subsampling import Conv2dSubsampling, VggSubsampling
 from icefall.utils import get_texts
+from torch.nn.utils.rnn import pad_sequence

 # Note: TorchScript requires Dict/List/etc. to be fully typed.
 Supervisions = Dict[str, torch.Tensor]
@@ -177,14 +178,17 @@ class Transformer(nn.Module):
             x = x.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
             x = self.feat_batchnorm(x)
             x = x.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
-        encoder_memory, memory_key_padding_mask = self.encode(x, supervision)
-        x = self.encoder_output(encoder_memory)
+        encoder_memory, memory_key_padding_mask = self.run_encoder(
+            x, supervision
+        )
+        x = self.ctc_output(encoder_memory)
         return x, encoder_memory, memory_key_padding_mask

-    def encode(
+    def run_encoder(
         self, x: torch.Tensor, supervisions: Optional[Supervisions] = None
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        """
+        """Run the transformer encoder.
+
         Args:
           x:
             The model input. Its shape is [N, T, C].
@@ -194,8 +198,8 @@ class Transformer(nn.Module):
            CAUTION: It contains length information, i.e., start and number of
            frames, before subsampling
            It is read directly from the batch, without any sorting. It is used
-           to compute encoder padding mask, which is used as memory key padding
-           mask for the decoder.
+           to compute the encoder padding mask, which is used as memory key
+           padding mask for the decoder.
         Returns:
           Return a tuple with two tensors:
            - The encoder output, with shape [T, N, C]
@@ -212,7 +216,7 @@ class Transformer(nn.Module):

         return x, mask

-    def encoder_output(self, x: torch.Tensor) -> torch.Tensor:
+    def ctc_output(self, x: torch.Tensor) -> torch.Tensor:
         """
         Args:
           x:
@@ -232,46 +236,16 @@ class Transformer(nn.Module):
         self,
         memory: torch.Tensor,
         memory_key_padding_mask: torch.Tensor,
-        supervision: Optional[Supervisions] = None,
-        L_inv: Optional[k2.Fsa] = None,
-        word_table: Optional[k2.SymbolTable] = None,
-        oov_str: Optional[str] = None,
-        token_ids: List[List[int]] = None,
-        sos_id: Optional[int] = None,
-        eos_id: Optional[int] = None,
+        token_ids: List[List[int]],
+        sos_id: int,
+        eos_id: int,
     ) -> torch.Tensor:
         """
-        Note:
-          If phone based lexicon is used, the following arguments are required:
-            - supervision
-            - L_inv
-            - word_table
-            - oov_str
-          If BPE based lexicon is used, the following arguments are required:
-            - token_ids
-            - sos_id
-            - eos_id
-
         Args:
           memory:
             It's the output of the encoder with shape [T, N, C]
           memory_key_padding_mask:
             The padding mask from the encoder.
-          supervision:
-            Supervision in lhotse format.
-            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
-            (CAUTION: It contains length information, i.e., start and number of
-            frames, before subsampling)
-          L_inv:
-            It is an FSA with labels being word IDs and aux_labels being
-            token IDs (e.g., phone IDs or word piece IDs).
-          word_table:
-            Word table providing mapping between words and IDs.
-          oov_str:
-            The OOV word, e.g., '<UNK>'
           token_ids:
             A list-of-list IDs. Each sublist contains IDs for an utterance.
             The IDs can be either phone IDs or word piece IDs.
@@ -284,29 +258,13 @@ class Transformer(nn.Module):
             A scalar, the **sum** of label smoothing loss over utterances
             in the batch without any normalization.
         """
-        if supervision is not None and word_table is not None:
-            batch_text = get_normal_transcripts(
-                supervision, word_table, oov_str
-            )
-            ys_in_pad, ys_out_pad = add_sos_eos(
-                batch_text,
-                L_inv,
-                sos_id,
-                eos_id,
-            )
-        elif token_ids is not None:
-            _sos = torch.tensor([sos_id])
-            _eos = torch.tensor([eos_id])
-            ys_in = [
-                torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids
-            ]
-            ys_out = [
-                torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids
-            ]
-            ys_in_pad = pad_list(ys_in, eos_id)
-            ys_out_pad = pad_list(ys_out, -1)
-        else:
-            raise ValueError("Invalid input for decoder self attention")
+        ys_in = add_sos(token_ids, sos_id=sos_id)
+        ys_in = [torch.tensor(y) for y in ys_in]
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
+
+        ys_out = add_eos(token_ids, eos_id=eos_id)
+        ys_out = [torch.tensor(y) for y in ys_out]
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)

         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
@@ -316,6 +274,8 @@ class Transformer(nn.Module):
             device
         )

+        # TODO: Use eos_id as ignore_id.
+        # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)
         tgt = self.decoder_embed(ys_in_pad)  # (N, T) -> (N, T, C)
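
A self-contained sketch (not part of the commit; add_sos/add_eos are inlined and a shared SOS/EOS ID of 1 is assumed) of what the new preprocessing above produces for a toy batch:

import torch
from torch.nn.utils.rnn import pad_sequence

token_ids = [[1, 2], [3], [2, 5, 8]]
sos_id = eos_id = 1  # assumed shared <sos/eos> ID

ys_in = [torch.tensor([sos_id] + y) for y in token_ids]   # add_sos
ys_out = [torch.tensor(y + [eos_id]) for y in token_ids]  # add_eos
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)

print(ys_in_pad)
# tensor([[1, 1, 2, 1],
#         [1, 3, 1, 1],
#         [1, 2, 5, 8]])
print(ys_out_pad)
# tensor([[ 1,  2,  1, -1],
#         [ 3,  1, -1, -1],
#         [ 2,  5,  8,  1]])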
@@ -362,19 +322,14 @@ class Transformer(nn.Module):
         """
         # The common part between this function and decoder_forward could be
         # extracted as a separate function.
-        if token_ids is not None:
-            _sos = torch.tensor([sos_id])
-            _eos = torch.tensor([eos_id])
-            ys_in = [
-                torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids
-            ]
-            ys_out = [
-                torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids
-            ]
-            ys_in_pad = pad_list(ys_in, eos_id)
-            ys_out_pad = pad_list(ys_out, -1)
-        else:
-            raise ValueError("Invalid input for decoder self attention")
+
+        ys_in = add_sos(token_ids, sos_id=sos_id)
+        ys_in = [torch.tensor(y) for y in ys_in]
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
+
+        ys_out = add_eos(token_ids, eos_id=eos_id)
+        ys_out = [torch.tensor(y) for y in ys_out]
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)

         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
@@ -384,6 +339,8 @@ class Transformer(nn.Module):
             device
         )

+        # TODO: Use eos_id as ignore_id.
+        # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)
         tgt = self.decoder_embed(ys_in_pad)  # (B, T) -> (B, T, F)
@@ -948,8 +905,8 @@ def decoder_padding_mask(
 ) -> torch.Tensor:
     """Generate a length mask for input.

-    The masked position are filled with bool(True),
-    Unmasked positions are filled with bool(False).
+    The masked position are filled with True,
+    Unmasked positions are filled with False.

     Args:
       ys_pad:
@@ -965,45 +922,16 @@ def decoder_padding_mask(
     return ys_mask


-def get_normal_transcripts(
-    supervision: Supervisions, words: k2.SymbolTable, oov: str = "<UNK>"
-) -> List[List[int]]:
-    """Get normal transcripts (1 input recording has 1 transcript)
-    from lhotse cut format.
-
-    Achieved by concatenating the transcripts corresponding to the
-    same recording.
-
-    Args:
-      supervision:
-        Supervision in lhotse format.
-        See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
-      words:
-        The word symbol table.
-      oov:
-        Out of vocabulary word.
-
-    Returns:
-      List[List[int]]: List of concatenated transcripts, length is batch_size
-    """
-    texts = [
-        [token if token in words else oov for token in text.split(" ")]
-        for text in supervision["text"]
-    ]
-    texts_ids = [[words[token] for token in text] for text in texts]
-    batch_text = [
-        [] for _ in range(int(supervision["sequence_idx"].max().item()) + 1)
-    ]
-    for sequence_idx, text in zip(supervision["sequence_idx"], texts_ids):
-        batch_text[sequence_idx] = batch_text[sequence_idx] + text
-    return batch_text
-
-
 def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
     """Generate a square mask for the sequence. The masked positions are
     filled with float('-inf'). Unmasked positions are filled with float(0.0).
+    The mask can be used for masked self-attention.
+
+    For instance, if sz is 3, it returns::
+
+        tensor([[0., -inf, -inf],
+                [0., 0., -inf],
+                [0., 0., 0]])

     Args:
       sz: mask size
@@ -1020,115 +948,41 @@ def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
     return mask


-def add_sos_eos(
-    ys: List[List[int]],
-    L_inv: k2.Fsa,
-    sos_id: int,
-    eos_id: int,
-    ignore_id: int = -1,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Add <sos> and <eos> labels.
-
-    Args:
-      ys:
-        Batch of unpadded target sequences (i.e., word IDs)
-      L_inv:
-        Its labels are words, while its aux_labels are tokens.
-      sos_id:
-        index of <sos>
-      eos_id:
-        index of <eos>
-      ignore_id:
-        value for padding
-
-    Returns:
-      Return a tuple containing two tensors:
-        - Input of transformer decoder.
-          Padded tensor of dimension (batch_size, max_length).
-        - Output of transformer decoder.
-          Padded tensor of dimension (batch_size, max_length).
-    """
-    _sos = torch.tensor([sos_id])
-    _eos = torch.tensor([eos_id])
-    ys = get_hierarchical_targets(ys, L_inv)
-    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
-    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
-    return pad_list(ys_in, eos_id), pad_list(ys_out, ignore_id)
-
-
-def pad_list(ys: List[torch.Tensor], pad_value: float) -> torch.Tensor:
-    """Perform padding for the list of tensors.
-
-    Args:
-      ys:
-        List of tensors. len(ys) = batch_size.
-      pad_value:
-        Value for padding.
-
-    Returns:
-      Tensor: Padded tensor (batch_size, max_length, `*`).
-
-    Examples:
-        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
-        >>> x
-        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
-        >>> pad_list(x, 0)
-        tensor([[1., 1., 1., 1.],
-                [1., 1., 0., 0.],
-                [1., 0., 0., 0.]])
-    """
-    n_batch = len(ys)
-    max_len = max(x.size(0) for x in ys)
-    pad = ys[0].new_full((n_batch, max_len, *ys[0].size()[1:]), pad_value)
-
-    for i in range(n_batch):
-        pad[i, : ys[i].size(0)] = ys[i]
-
-    return pad
-
-
-def get_hierarchical_targets(
-    ys: List[List[int]], L_inv: Optional[k2.Fsa] = None
-) -> List[torch.Tensor]:
-    """Get hierarchical transcripts (i.e., phone level transcripts) from
-    transcripts (i.e., word level transcripts).
-
-    Args:
-      ys:
-        Word level transcripts. Each sublist is a transcript of an utterance.
-      L_inv:
-        Its labels are words, while its aux_labels are tokens.
-
-    Returns:
-      List[torch.Tensor]:
-        Token level transcripts.
-    """
-    if L_inv is None:
-        return [torch.tensor(y) for y in ys]
-
-    device = L_inv.device
-
-    transcripts = k2.create_fsa_vec(
-        [k2.linear_fsa(x, device=device) for x in ys]
-    )
-    transcripts_with_self_loops = k2.add_epsilon_self_loops(transcripts)
-
-    transcripts_lexicon = k2.intersect(
-        L_inv, transcripts_with_self_loops, treat_epsilons_specially=False
-    )
-    # Don't call invert_() above because we want to return phone IDs,
-    # which is the `aux_labels` of transcripts_lexicon
-    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
-    transcripts_lexicon = k2.top_sort(transcripts_lexicon)
-
-    transcripts_lexicon = k2.shortest_path(
-        transcripts_lexicon, use_double_scores=True
-    )
-
-    ys = get_texts(transcripts_lexicon)
-    ys = [torch.tensor(y) for y in ys]
-
-    return ys
+def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
+    """Prepend sos_id to each utterance.
+
+    Args:
+      token_ids:
+        A list-of-list of token IDs. Each sublist contains
+        token IDs (e.g., word piece IDs) of an utterance.
+      sos_id:
+        The ID of the SOS token.
+
+    Return:
+      Return a new list-of-list, where each sublist starts
+      with SOS ID.
+    """
+    ans = []
+    for utt in token_ids:
+        ans.append([sos_id] + utt)
+    return ans
+
+
+def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
+    """Append eos_id to each utterance.
+
+    Args:
+      token_ids:
+        A list-of-list of token IDs. Each sublist contains
+        token IDs (e.g., word piece IDs) of an utterance.
+      eos_id:
+        The ID of the EOS token.
+
+    Return:
+      Return a new list-of-list, where each sublist ends
+      with EOS ID.
+    """
+    ans = []
+    for utt in token_ids:
+        ans.append(utt + [eos_id])
+    return ans
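
Closing note (not part of the commit): for the 1-D token tensors used in this file, the removed pad_list helper and torch.nn.utils.rnn.pad_sequence(..., batch_first=True) produce the same result, which is why pad_list could be removed along with the other unused helpers. A quick check mirroring the removed docstring example:

import torch
from torch.nn.utils.rnn import pad_sequence

x = [torch.ones(4), torch.ones(2), torch.ones(1)]
print(pad_sequence(x, batch_first=True, padding_value=0))
# tensor([[1., 1., 1., 1.],
#         [1., 1., 0., 0.],
#         [1., 0., 0., 0.]])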