Mirror of https://github.com/k2-fsa/icefall.git
Remove unused code.
This commit is contained in:
commit 2be7a0a555 (parent f6091b10c0)
@@ -84,7 +84,7 @@ class Conformer(Transformer):
             # and throws an error without this change.
             self.after_norm = identity

-    def encode(
+    def run_encoder(
         self, x: Tensor, supervisions: Optional[Supervisions] = None
     ) -> Tuple[Tensor, Optional[Tensor]]:
         """
@@ -802,7 +802,8 @@ class RelPositionMultiheadAttention(nn.Module):
                 bsz, num_heads, tgt_len, src_len
             )
             attn_output_weights = attn_output_weights.masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"),
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                float("-inf"),
             )
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
@@ -872,7 +873,12 @@ class ConvolutionModule(nn.Module):
         )
         self.norm = nn.BatchNorm1d(channels)
         self.pointwise_conv2 = nn.Conv1d(
-            channels, channels, kernel_size=1, stride=1, padding=0, bias=bias,
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
         )
         self.activation = Swish()

@@ -15,6 +15,7 @@ import torch
 import torch.nn as nn
 from conformer import Conformer

+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.decode import (
@@ -85,7 +86,7 @@ def get_params() -> AttributeDict:
             # - whole-lattice-rescoring
             # - attention-decoder
             # "method": "whole-lattice-rescoring",
-            "method": "1best",
+            "method": "attention-decoder",
             # num_paths is used when method is "nbest", "nbest-rescoring",
             # and attention-decoder
             "num_paths": 100,
@@ -100,6 +101,8 @@ def decode_one_batch(
     HLG: k2.Fsa,
     batch: dict,
     lexicon: Lexicon,
+    sos_id: int,
+    eos_id: int,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[int]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -133,6 +136,10 @@ def decode_one_batch(
         for the format of the `batch`.
       lexicon:
         It contains word symbol table.
+      sos_id:
+        The token ID of the SOS.
+      eos_id:
+        The token ID of the EOS.
       G:
         An LM. It is not None when params.method is "nbest-rescoring"
         or "whole-lattice-rescoring". In general, the G in HLG
@@ -222,8 +229,8 @@ def decode_one_batch(
             model=model,
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
-            sos_id=lexicon.sos_id,
-            eos_id=lexicon.eos_id,
+            sos_id=sos_id,
+            eos_id=eos_id,
         )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
@@ -242,6 +249,8 @@ def decode_dataset(
     model: nn.Module,
     HLG: k2.Fsa,
     lexicon: Lexicon,
+    sos_id: int,
+    eos_id: int,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[int], List[int]]]]:
     """Decode dataset.
@@ -257,6 +266,10 @@ def decode_dataset(
         The decoding graph.
       lexicon:
         It contains word symbol table.
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
       G:
         An LM. It is not None when params.method is "nbest-rescoring"
         or "whole-lattice-rescoring". In general, the G in HLG
@@ -284,6 +297,8 @@ def decode_dataset(
             batch=batch,
             lexicon=lexicon,
             G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
         )

         for lm_scale, hyps in hyps_dict.items():
@@ -364,6 +379,15 @@ def main():

     logging.info(f"device: {device}")

+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+    sos_id = graph_compiler.sos_id
+    eos_id = graph_compiler.eos_id
+
     HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
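
The two new arguments threaded through decode.py above ultimately come from the BPE lang dir. Below is a hypothetical helper, not part of this commit, sketching where IDs like the ones BpeCtcTrainingGraphCompiler exposes as .sos_id/.eos_id could come from; it assumes the lang dir contains a sentencepiece model named bpe.model that defines the "<sos/eos>" piece, and the graph compiler is assumed to resolve them in a similar way.

    from pathlib import Path

    import sentencepiece as spm


    def get_token_id(lang_dir: str, token: str = "<sos/eos>") -> int:
        # Assumption: lang_dir holds the sentencepiece model used to build
        # the BPE lexicon. Look up the integer ID of the given piece.
        sp = spm.SentencePieceProcessor()
        sp.load(str(Path(lang_dir) / "bpe.model"))
        return sp.piece_to_id(token)
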
@@ -456,6 +480,8 @@ def main():
             HLG=HLG,
             lexicon=lexicon,
             G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
         )

         save_results(
@@ -1,7 +1,16 @@
 #!/usr/bin/env python3

 import torch
-from transformer import Transformer, encoder_padding_mask
+from transformer import (
+    Transformer,
+    encoder_padding_mask,
+    generate_square_subsequent_mask,
+    decoder_padding_mask,
+    add_sos,
+    add_eos,
+)
+
+from torch.nn.utils.rnn import pad_sequence


 def test_encoder_padding_mask():
@@ -34,3 +43,47 @@ def test_transformer():
     x = torch.rand(N, T, num_features)
     y, _, _ = model(x)
     assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)
+
+
+def test_generate_square_subsequent_mask():
+    s = 5
+    mask = generate_square_subsequent_mask(s)
+    inf = float("inf")
+    expected_mask = torch.tensor(
+        [
+            [0.0, -inf, -inf, -inf, -inf],
+            [0.0, 0.0, -inf, -inf, -inf],
+            [0.0, 0.0, 0.0, -inf, -inf],
+            [0.0, 0.0, 0.0, 0.0, -inf],
+            [0.0, 0.0, 0.0, 0.0, 0.0],
+        ]
+    )
+    assert torch.all(torch.eq(mask, expected_mask))
+
+
+def test_decoder_padding_mask():
+    x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])]
+    y = pad_sequence(x, batch_first=True, padding_value=-1)
+    mask = decoder_padding_mask(y, ignore_id=-1)
+    expected_mask = torch.tensor(
+        [
+            [False, False, True],
+            [False, True, True],
+            [False, False, False],
+        ]
+    )
+    assert torch.all(torch.eq(mask, expected_mask))
+
+
+def test_add_sos():
+    x = [[1, 2], [3], [2, 5, 8]]
+    y = add_sos(x, sos_id=0)
+    expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]]
+    assert y == expected_y
+
+
+def test_add_eos():
+    x = [[1, 2], [3], [2, 5, 8]]
+    y = add_eos(x, eos_id=0)
+    expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]]
+    assert y == expected_y
@@ -10,6 +10,7 @@ import torch.nn as nn
 from subsampling import Conv2dSubsampling, VggSubsampling

 from icefall.utils import get_texts
+from torch.nn.utils.rnn import pad_sequence

 # Note: TorchScript requires Dict/List/etc. to be fully typed.
 Supervisions = Dict[str, torch.Tensor]
@@ -177,14 +178,17 @@ class Transformer(nn.Module):
         x = x.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
         x = self.feat_batchnorm(x)
         x = x.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
-        encoder_memory, memory_key_padding_mask = self.encode(x, supervision)
-        x = self.encoder_output(encoder_memory)
+        encoder_memory, memory_key_padding_mask = self.run_encoder(
+            x, supervision
+        )
+        x = self.ctc_output(encoder_memory)
         return x, encoder_memory, memory_key_padding_mask

-    def encode(
+    def run_encoder(
         self, x: torch.Tensor, supervisions: Optional[Supervisions] = None
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        """
+        """Run the transformer encoder.
+
         Args:
           x:
             The model input. Its shape is [N, T, C].
@@ -194,8 +198,8 @@ class Transformer(nn.Module):
             CAUTION: It contains length information, i.e., start and number of
             frames, before subsampling
             It is read directly from the batch, without any sorting. It is used
-            to compute encoder padding mask, which is used as memory key padding
-            mask for the decoder.
+            to compute the encoder padding mask, which is used as memory key
+            padding mask for the decoder.
         Returns:
           Return a tuple with two tensors:
             - The encoder output, with shape [T, N, C]
@@ -212,7 +216,7 @@ class Transformer(nn.Module):

         return x, mask

-    def encoder_output(self, x: torch.Tensor) -> torch.Tensor:
+    def ctc_output(self, x: torch.Tensor) -> torch.Tensor:
         """
         Args:
           x:
@@ -232,46 +236,16 @@ class Transformer(nn.Module):
         self,
         memory: torch.Tensor,
         memory_key_padding_mask: torch.Tensor,
-        supervision: Optional[Supervisions] = None,
-        L_inv: Optional[k2.Fsa] = None,
-        word_table: Optional[k2.SymbolTable] = None,
-        oov_str: Optional[str] = None,
-        token_ids: List[List[int]] = None,
-        sos_id: Optional[int] = None,
-        eos_id: Optional[int] = None,
+        token_ids: List[List[int]],
+        sos_id: int,
+        eos_id: int,
     ) -> torch.Tensor:
         """
-        Note:
-          If phone based lexicon is used, the following arguments are required:
-
-            - supervision
-            - L_inv
-            - word_table
-            - oov_str
-
-          If BPE based lexicon is used, the following arguments are required:
-
-            - token_ids
-            - sos_id
-            - eos_id
-
         Args:
           memory:
             It's the output of the encoder with shape [T, N, C]
           memory_key_padding_mask:
             The padding mask from the encoder.
-          supervision:
-            Supervision in lhotse format.
-            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
-            (CAUTION: It contains length information, i.e., start and number of
-            frames, before subsampling)
-          L_inv:
-            It is an FSA with labels being word IDs and aux_labels being
-            token IDs (e.g., phone IDs or word piece IDs).
-          word_table:
-            Word table providing mapping between words and IDs.
-          oov_str:
-            The OOV word, e.g., '<UNK>'
           token_ids:
             A list-of-list IDs. Each sublist contains IDs for an utterance.
             The IDs can be either phone IDs or word piece IDs.
@@ -284,29 +258,13 @@ class Transformer(nn.Module):
             A scalar, the **sum** of label smoothing loss over utterances
             in the batch without any normalization.
         """
-        if supervision is not None and word_table is not None:
-            batch_text = get_normal_transcripts(
-                supervision, word_table, oov_str
-            )
-            ys_in_pad, ys_out_pad = add_sos_eos(
-                batch_text,
-                L_inv,
-                sos_id,
-                eos_id,
-            )
-        elif token_ids is not None:
-            _sos = torch.tensor([sos_id])
-            _eos = torch.tensor([eos_id])
-            ys_in = [
-                torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids
-            ]
-            ys_out = [
-                torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids
-            ]
-            ys_in_pad = pad_list(ys_in, eos_id)
-            ys_out_pad = pad_list(ys_out, -1)
-        else:
-            raise ValueError("Invalid input for decoder self attention")
+        ys_in = add_sos(token_ids, sos_id=sos_id)
+        ys_in = [torch.tensor(y) for y in ys_in]
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
+
+        ys_out = add_eos(token_ids, eos_id=eos_id)
+        ys_out = [torch.tensor(y) for y in ys_out]
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)

         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
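
The hunk above is the heart of the commit: decoder_forward() now builds its padded decoder inputs and outputs from plain token-ID lists with add_sos/add_eos and torch's pad_sequence. A self-contained sketch of what that produces for a toy batch; the helper bodies below are functionally equivalent copies of the helpers added later in this diff, and the ID values are placeholders.

    from typing import List

    import torch
    from torch.nn.utils.rnn import pad_sequence


    def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
        # Prepend sos_id to every utterance (same behavior as in this diff).
        return [[sos_id] + utt for utt in token_ids]


    def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
        # Append eos_id to every utterance (same behavior as in this diff).
        return [utt + [eos_id] for utt in token_ids]


    token_ids = [[1, 2], [3]]  # two toy utterances of word-piece IDs
    sos_id, eos_id = 10, 11    # placeholder IDs for this sketch

    ys_in = [torch.tensor(y) for y in add_sos(token_ids, sos_id=sos_id)]
    ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
    # tensor([[10,  1,  2],
    #         [10,  3, 11]])  <- decoder input, padded with eos_id

    ys_out = [torch.tensor(y) for y in add_eos(token_ids, eos_id=eos_id)]
    ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
    # tensor([[ 1,  2, 11],
    #         [ 3, 11, -1]])  <- decoder target, padded with -1
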
@@ -316,6 +274,8 @@ class Transformer(nn.Module):
             device
         )

+        # TODO: Use eos_id as ignore_id.
+        # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)

         tgt = self.decoder_embed(ys_in_pad)  # (N, T) -> (N, T, C)
@@ -362,19 +322,14 @@ class Transformer(nn.Module):
         """
         # The common part between this function and decoder_forward could be
         # extracted as a separate function.
-        if token_ids is not None:
-            _sos = torch.tensor([sos_id])
-            _eos = torch.tensor([eos_id])
-            ys_in = [
-                torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids
-            ]
-            ys_out = [
-                torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids
-            ]
-            ys_in_pad = pad_list(ys_in, eos_id)
-            ys_out_pad = pad_list(ys_out, -1)
-        else:
-            raise ValueError("Invalid input for decoder self attention")
+        ys_in = add_sos(token_ids, sos_id=sos_id)
+        ys_in = [torch.tensor(y) for y in ys_in]
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
+
+        ys_out = add_eos(token_ids, eos_id=eos_id)
+        ys_out = [torch.tensor(y) for y in ys_out]
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)

         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
@@ -384,6 +339,8 @@ class Transformer(nn.Module):
             device
         )

+        # TODO: Use eos_id as ignore_id.
+        # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)

         tgt = self.decoder_embed(ys_in_pad)  # (B, T) -> (B, T, F)
@@ -948,8 +905,8 @@ def decoder_padding_mask(
 ) -> torch.Tensor:
     """Generate a length mask for input.

-    The masked position are filled with bool(True),
-    Unmasked positions are filled with bool(False).
+    The masked position are filled with True,
+    Unmasked positions are filled with False.

     Args:
       ys_pad:
@@ -965,45 +922,16 @@ def decoder_padding_mask(
     return ys_mask


-def get_normal_transcripts(
-    supervision: Supervisions, words: k2.SymbolTable, oov: str = "<UNK>"
-) -> List[List[int]]:
-    """Get normal transcripts (1 input recording has 1 transcript)
-    from lhotse cut format.
-
-    Achieved by concatenating the transcripts corresponding to the
-    same recording.
-
-    Args:
-      supervision:
-        Supervision in lhotse format.
-        See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
-      words:
-        The word symbol table.
-      oov:
-        Out of vocabulary word.
-
-    Returns:
-      List[List[int]]: List of concatenated transcripts, length is batch_size
-    """
-
-    texts = [
-        [token if token in words else oov for token in text.split(" ")]
-        for text in supervision["text"]
-    ]
-    texts_ids = [[words[token] for token in text] for text in texts]
-
-    batch_text = [
-        [] for _ in range(int(supervision["sequence_idx"].max().item()) + 1)
-    ]
-    for sequence_idx, text in zip(supervision["sequence_idx"], texts_ids):
-        batch_text[sequence_idx] = batch_text[sequence_idx] + text
-    return batch_text
-
-
 def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
     """Generate a square mask for the sequence. The masked positions are
     filled with float('-inf'). Unmasked positions are filled with float(0.0).
+    The mask can be used for masked self-attention.
+
+    For instance, if sz is 3, it returns::
+
+        tensor([[0., -inf, -inf],
+                [0., 0., -inf],
+                [0., 0., 0]])
+
     Args:
       sz: mask size
@@ -1020,115 +948,41 @@ def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
     return mask


-def add_sos_eos(
-    ys: List[List[int]],
-    L_inv: k2.Fsa,
-    sos_id: int,
-    eos_id: int,
-    ignore_id: int = -1,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Add <sos> and <eos> labels.
+def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
+    """Prepend sos_id to each utterance.

     Args:
-      ys:
-        Batch of unpadded target sequences (i.e., word IDs)
-      L_inv:
-        Its labels are words, while its aux_labels are tokens.
+      token_ids:
+        A list-of-list of token IDs. Each sublist contains
+        token IDs (e.g., word piece IDs) of an utterance.
       sos_id:
-        index of <sos>
+        The ID of the SOS token.
+
+    Return:
+      Return a new list-of-list, where each sublist starts
+      with SOS ID.
+    """
+    ans = []
+    for utt in token_ids:
+        ans.append([sos_id] + utt)
+    return ans
+
+
+def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
+    """Append eos_id to each utterance.
+
+    Args:
+      token_ids:
+        A list-of-list of token IDs. Each sublist contains
+        token IDs (e.g., word piece IDs) of an utterance.
       eos_id:
-        index of <eos>
-      ignore_id:
-        value for padding
+        The ID of the EOS token.

-    Returns:
-      Return a tuple containing two tensors:
-      - Input of transformer decoder.
-        Padded tensor of dimension (batch_size, max_length).
-      - Output of transformer decoder.
-        Padded tensor of dimension (batch_size, max_length).
+    Return:
+      Return a new list-of-list, where each sublist ends
+      with EOS ID.
     """
-    _sos = torch.tensor([sos_id])
-    _eos = torch.tensor([eos_id])
-    ys = get_hierarchical_targets(ys, L_inv)
-    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
-    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
-    return pad_list(ys_in, eos_id), pad_list(ys_out, ignore_id)
-
-
-def pad_list(ys: List[torch.Tensor], pad_value: float) -> torch.Tensor:
-    """Perform padding for the list of tensors.
-
-    Args:
-      ys:
-        List of tensors. len(ys) = batch_size.
-      pad_value:
-        Value for padding.
-
-    Returns:
-      Tensor: Padded tensor (batch_size, max_length, `*`).
-
-    Examples:
-        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
-        >>> x
-        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
-        >>> pad_list(x, 0)
-        tensor([[1., 1., 1., 1.],
-                [1., 1., 0., 0.],
-                [1., 0., 0., 0.]])
-
-    """
-    n_batch = len(ys)
-    max_len = max(x.size(0) for x in ys)
-    pad = ys[0].new_full((n_batch, max_len, *ys[0].size()[1:]), pad_value)
-
-    for i in range(n_batch):
-        pad[i, : ys[i].size(0)] = ys[i]
-
-    return pad
-
-
-def get_hierarchical_targets(
-    ys: List[List[int]], L_inv: Optional[k2.Fsa] = None
-) -> List[torch.Tensor]:
-    """Get hierarchical transcripts (i.e., phone level transcripts) from
-    transcripts (i.e., word level transcripts).
-
-    Args:
-      ys:
-        Word level transcripts. Each sublist is a transcript of an utterance.
-      L_inv:
-        Its labels are words, while its aux_labels are tokens.
-
-    Returns:
-      List[torch.Tensor]:
-        Token level transcripts.
-    """
-
-    if L_inv is None:
-        return [torch.tensor(y) for y in ys]
-
-    device = L_inv.device
-
-    transcripts = k2.create_fsa_vec(
-        [k2.linear_fsa(x, device=device) for x in ys]
-    )
-    transcripts_with_self_loops = k2.add_epsilon_self_loops(transcripts)
-
-    transcripts_lexicon = k2.intersect(
-        L_inv, transcripts_with_self_loops, treat_epsilons_specially=False
-    )
-    # Don't call invert_() above because we want to return phone IDs,
-    # which is the `aux_labels` of transcripts_lexicon
-    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
-    transcripts_lexicon = k2.top_sort(transcripts_lexicon)
-
-    transcripts_lexicon = k2.shortest_path(
-        transcripts_lexicon, use_double_scores=True
-    )
-
-    ys = get_texts(transcripts_lexicon)
-    ys = [torch.tensor(y) for y in ys]
-
-    return ys
+    ans = []
+    for utt in token_ids:
+        ans.append(utt + [eos_id])
+    return ans
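
A short, self-contained check of why the removed pad_list() is no longer needed: for the 1-D tensors used in this code, pad_sequence(batch_first=True) builds the same (batch_size, max_length) tensor, matching the example that was in pad_list's docstring.

    import torch
    from torch.nn.utils.rnn import pad_sequence

    x = [torch.ones(4), torch.ones(2), torch.ones(1)]
    padded = pad_sequence(x, batch_first=True, padding_value=0)
    # tensor([[1., 1., 1., 1.],
    #         [1., 1., 0., 0.],
    #         [1., 0., 0., 0.]])  <- same output the removed pad_list(x, 0) documented
    assert padded.shape == (3, 4)
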