mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00

Remove unused code.

This commit is contained in:
parent f6091b10c0
commit 2be7a0a555
@@ -84,7 +84,7 @@ class Conformer(Transformer):
         # and throws an error without this change.
         self.after_norm = identity

-    def encode(
+    def run_encoder(
         self, x: Tensor, supervisions: Optional[Supervisions] = None
     ) -> Tuple[Tensor, Optional[Tensor]]:
         """
@@ -802,7 +802,8 @@ class RelPositionMultiheadAttention(nn.Module):
                 bsz, num_heads, tgt_len, src_len
             )
             attn_output_weights = attn_output_weights.masked_fill(
-                key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"),
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                float("-inf"),
             )
             attn_output_weights = attn_output_weights.view(
                 bsz * num_heads, tgt_len, src_len
@@ -872,7 +873,12 @@ class ConvolutionModule(nn.Module):
         )
         self.norm = nn.BatchNorm1d(channels)
         self.pointwise_conv2 = nn.Conv1d(
-            channels, channels, kernel_size=1, stride=1, padding=0, bias=bias,
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
        )
         self.activation = Swish()

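Note (not part of the commit): the three hunks above are in the Conformer encoder module; `encode` is renamed to `run_encoder`, and the other two changes only re-wrap long argument lists. The re-wrapped `masked_fill` is the usual key-padding trick; the self-contained sketch below, with made-up shapes, shows why the mask is unsqueezed twice before being applied.

import torch

# Toy shapes; in the recipe these come from the attention module itself.
bsz, num_heads, tgt_len, src_len = 2, 4, 5, 7
attn_output_weights = torch.randn(bsz, num_heads, tgt_len, src_len)

# True marks padded frames that must receive no attention.
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
key_padding_mask[0, 5:] = True  # last two frames of the first utterance

# unsqueeze(1).unsqueeze(2) broadcasts (bsz, src_len) over heads and queries.
attn_output_weights = attn_output_weights.masked_fill(
    key_padding_mask.unsqueeze(1).unsqueeze(2),
    float("-inf"),
)
attn_output_weights = attn_output_weights.softmax(dim=-1)

# Padded keys end up with exactly zero attention weight.
assert float(attn_output_weights[0, :, :, 5:].sum()) == 0.0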
@@ -15,6 +15,7 @@ import torch
 import torch.nn as nn
 from conformer import Conformer

+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.decode import (
@@ -85,7 +86,7 @@ def get_params() -> AttributeDict:
             # - whole-lattice-rescoring
             # - attention-decoder
             # "method": "whole-lattice-rescoring",
-            "method": "1best",
+            "method": "attention-decoder",
             # num_paths is used when method is "nbest", "nbest-rescoring",
             # and attention-decoder
             "num_paths": 100,
@@ -100,6 +101,8 @@ def decode_one_batch(
     HLG: k2.Fsa,
     batch: dict,
     lexicon: Lexicon,
+    sos_id: int,
+    eos_id: int,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[int]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -133,6 +136,10 @@ def decode_one_batch(
         for the format of the `batch`.
       lexicon:
         It contains word symbol table.
+      sos_id:
+        The token ID of the SOS.
+      eos_id:
+        The token ID of the EOS.
       G:
         An LM. It is not None when params.method is "nbest-rescoring"
         or "whole-lattice-rescoring". In general, the G in HLG
@@ -222,8 +229,8 @@ def decode_one_batch(
             model=model,
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
-            sos_id=lexicon.sos_id,
-            eos_id=lexicon.eos_id,
+            sos_id=sos_id,
+            eos_id=eos_id,
         )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
@@ -242,6 +249,8 @@ def decode_dataset(
     model: nn.Module,
     HLG: k2.Fsa,
     lexicon: Lexicon,
+    sos_id: int,
+    eos_id: int,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[int], List[int]]]]:
     """Decode dataset.
@@ -257,6 +266,10 @@ def decode_dataset(
         The decoding graph.
       lexicon:
         It contains word symbol table.
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
       G:
         An LM. It is not None when params.method is "nbest-rescoring"
         or "whole-lattice-rescoring". In general, the G in HLG
@@ -284,6 +297,8 @@ def decode_dataset(
             batch=batch,
             lexicon=lexicon,
             G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
         )

         for lm_scale, hyps in hyps_dict.items():
@@ -364,6 +379,15 @@ def main():

     logging.info(f"device: {device}")

+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+    sos_id = graph_compiler.sos_id
+    eos_id = graph_compiler.eos_id
+
     HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
@@ -456,6 +480,8 @@ def main():
             HLG=HLG,
             lexicon=lexicon,
             G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
         )

         save_results(
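Note (not part of the commit): the hunks above change the decoding script so that `sos_id`/`eos_id` come from a `BpeCtcTrainingGraphCompiler` instead of the lexicon, and the default method becomes attention-decoder rescoring. A condensed, hedged sketch of that plumbing, with a placeholder lang dir standing in for `params.lang_dir`:

import torch

from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler

graph_compiler = BpeCtcTrainingGraphCompiler(
    "data/lang_bpe",  # placeholder; the script uses params.lang_dir
    device=torch.device("cpu"),
    sos_token="<sos/eos>",
    eos_token="<sos/eos>",
)
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id

# main() hands the IDs to decode_dataset(..., sos_id=sos_id, eos_id=eos_id),
# which forwards them to decode_one_batch(); the attention-decoder rescoring
# call then receives sos_id/eos_id instead of lexicon.sos_id/lexicon.eos_id.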
@@ -1,7 +1,16 @@
 #!/usr/bin/env python3

 import torch
-from transformer import Transformer, encoder_padding_mask
+from transformer import (
+    Transformer,
+    encoder_padding_mask,
+    generate_square_subsequent_mask,
+    decoder_padding_mask,
+    add_sos,
+    add_eos,
+)
+
+from torch.nn.utils.rnn import pad_sequence


 def test_encoder_padding_mask():
@@ -34,3 +43,47 @@ def test_transformer():
     x = torch.rand(N, T, num_features)
     y, _, _ = model(x)
     assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)
+
+
+def test_generate_square_subsequent_mask():
+    s = 5
+    mask = generate_square_subsequent_mask(s)
+    inf = float("inf")
+    expected_mask = torch.tensor(
+        [
+            [0.0, -inf, -inf, -inf, -inf],
+            [0.0, 0.0, -inf, -inf, -inf],
+            [0.0, 0.0, 0.0, -inf, -inf],
+            [0.0, 0.0, 0.0, 0.0, -inf],
+            [0.0, 0.0, 0.0, 0.0, 0.0],
+        ]
+    )
+    assert torch.all(torch.eq(mask, expected_mask))
+
+
+def test_decoder_padding_mask():
+    x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])]
+    y = pad_sequence(x, batch_first=True, padding_value=-1)
+    mask = decoder_padding_mask(y, ignore_id=-1)
+    expected_mask = torch.tensor(
+        [
+            [False, False, True],
+            [False, True, True],
+            [False, False, False],
+        ]
+    )
+    assert torch.all(torch.eq(mask, expected_mask))
+
+
+def test_add_sos():
+    x = [[1, 2], [3], [2, 5, 8]]
+    y = add_sos(x, sos_id=0)
+    expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]]
+    assert y == expected_y
+
+
+def test_add_eos():
+    x = [[1, 2], [3], [2, 5, 8]]
+    y = add_eos(x, eos_id=0)
+    expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]]
+    assert y == expected_y
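Note (not part of the commit): the new tests above double as usage examples for `generate_square_subsequent_mask`, `decoder_padding_mask`, `add_sos`, and `add_eos`. For orientation only, a minimal reference that is consistent with `test_decoder_padding_mask` (an assumption about the helper's behaviour, not the recipe's actual code):

import torch


def decoder_padding_mask_ref(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor:
    # True marks padded positions, False marks real tokens.
    return ys_pad == ignore_id


y = torch.tensor([[1, 2, -1], [3, -1, -1], [2, 5, 8]])
assert decoder_padding_mask_ref(y).tolist() == [
    [False, False, True],
    [False, True, True],
    [False, False, False],
]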
@@ -10,6 +10,7 @@ import torch.nn as nn
 from subsampling import Conv2dSubsampling, VggSubsampling

 from icefall.utils import get_texts
+from torch.nn.utils.rnn import pad_sequence

 # Note: TorchScript requires Dict/List/etc. to be fully typed.
 Supervisions = Dict[str, torch.Tensor]
@@ -177,14 +178,17 @@ class Transformer(nn.Module):
         x = x.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
         x = self.feat_batchnorm(x)
         x = x.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
-        encoder_memory, memory_key_padding_mask = self.encode(x, supervision)
-        x = self.encoder_output(encoder_memory)
+        encoder_memory, memory_key_padding_mask = self.run_encoder(
+            x, supervision
+        )
+        x = self.ctc_output(encoder_memory)
         return x, encoder_memory, memory_key_padding_mask

-    def encode(
+    def run_encoder(
         self, x: torch.Tensor, supervisions: Optional[Supervisions] = None
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        """
+        """Run the transformer encoder.
+
         Args:
           x:
             The model input. Its shape is [N, T, C].
@@ -194,8 +198,8 @@ class Transformer(nn.Module):
             CAUTION: It contains length information, i.e., start and number of
             frames, before subsampling
             It is read directly from the batch, without any sorting. It is used
-            to compute encoder padding mask, which is used as memory key padding
-            mask for the decoder.
+            to compute the encoder padding mask, which is used as memory key
+            padding mask for the decoder.
         Returns:
           Return a tuple with two tensors:
             - The encoder output, with shape [T, N, C]
@@ -212,7 +216,7 @@ class Transformer(nn.Module):

         return x, mask

-    def encoder_output(self, x: torch.Tensor) -> torch.Tensor:
+    def ctc_output(self, x: torch.Tensor) -> torch.Tensor:
         """
         Args:
           x:
@@ -232,46 +236,16 @@ class Transformer(nn.Module):
         self,
         memory: torch.Tensor,
         memory_key_padding_mask: torch.Tensor,
-        supervision: Optional[Supervisions] = None,
-        L_inv: Optional[k2.Fsa] = None,
-        word_table: Optional[k2.SymbolTable] = None,
-        oov_str: Optional[str] = None,
-        token_ids: List[List[int]] = None,
-        sos_id: Optional[int] = None,
-        eos_id: Optional[int] = None,
+        token_ids: List[List[int]],
+        sos_id: int,
+        eos_id: int,
     ) -> torch.Tensor:
         """
-        Note:
-          If phone based lexicon is used, the following arguments are required:
-
-            - supervision
-            - L_inv
-            - word_table
-            - oov_str
-
-          If BPE based lexicon is used, the following arguments are required:
-
-            - token_ids
-            - sos_id
-            - eos_id
-
         Args:
           memory:
             It's the output of the encoder with shape [T, N, C]
          memory_key_padding_mask:
            The padding mask from the encoder.
-          supervision:
-            Supervision in lhotse format.
-            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
-            (CAUTION: It contains length information, i.e., start and number of
-            frames, before subsampling)
-          L_inv:
-            It is an FSA with labels being word IDs and aux_labels being
-            token IDs (e.g., phone IDs or word piece IDs).
-          word_table:
-            Word table providing mapping between words and IDs.
-          oov_str:
-            The OOV word, e.g., '<UNK>'
           token_ids:
             A list-of-list IDs. Each sublist contains IDs for an utterance.
             The IDs can be either phone IDs or word piece IDs.
@@ -284,29 +258,13 @@ class Transformer(nn.Module):
            A scalar, the **sum** of label smoothing loss over utterances
            in the batch without any normalization.
         """
-        if supervision is not None and word_table is not None:
-            batch_text = get_normal_transcripts(
-                supervision, word_table, oov_str
-            )
-            ys_in_pad, ys_out_pad = add_sos_eos(
-                batch_text,
-                L_inv,
-                sos_id,
-                eos_id,
-            )
-        elif token_ids is not None:
-            _sos = torch.tensor([sos_id])
-            _eos = torch.tensor([eos_id])
-            ys_in = [
-                torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids
-            ]
-            ys_out = [
-                torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids
-            ]
-            ys_in_pad = pad_list(ys_in, eos_id)
-            ys_out_pad = pad_list(ys_out, -1)
-        else:
-            raise ValueError("Invalid input for decoder self attention")
+        ys_in = add_sos(token_ids, sos_id=sos_id)
+        ys_in = [torch.tensor(y) for y in ys_in]
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
+
+        ys_out = add_eos(token_ids, eos_id=eos_id)
+        ys_out = [torch.tensor(y) for y in ys_out]
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)

         device = memory.device
         ys_in_pad = ys_in_pad.to(device)
@@ -316,6 +274,8 @@ class Transformer(nn.Module):
             device
         )

+        # TODO: Use eos_id as ignore_id.
+        # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)

         tgt = self.decoder_embed(ys_in_pad)  # (N, T) -> (N, T, C)
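Note (not part of the commit): `decoder_forward` above, and `decoder_nll` in the next hunk, now build teacher-forcing targets from `add_sos`/`add_eos` plus `torch.nn.utils.rnn.pad_sequence`, replacing the phone-lexicon branch and the hand-rolled `pad_list`. A worked example of the two padded tensors for a toy batch (`sos_id = eos_id = 0` is only an illustration):

import torch
from torch.nn.utils.rnn import pad_sequence

token_ids = [[1, 2], [3], [2, 5, 8]]
sos_id = eos_id = 0

ys_in = [[sos_id] + utt for utt in token_ids]   # what add_sos returns
ys_out = [utt + [eos_id] for utt in token_ids]  # what add_eos returns

# Decoder input: padded with eos_id, so padding is a harmless repeated EOS.
ys_in_pad = pad_sequence(
    [torch.tensor(y) for y in ys_in], batch_first=True, padding_value=eos_id
)
# Decoder target: padded with -1 so padded positions can be ignored downstream.
ys_out_pad = pad_sequence(
    [torch.tensor(y) for y in ys_out], batch_first=True, padding_value=-1
)

# ys_in_pad:  [[0, 1, 2, 0], [0, 3, 0, 0], [0, 2, 5, 8]]
# ys_out_pad: [[1, 2, 0, -1], [3, 0, -1, -1], [2, 5, 8, 0]]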
@@ -362,19 +322,14 @@ class Transformer(nn.Module):
         """
         # The common part between this function and decoder_forward could be
         # extracted as a separate function.
-        if token_ids is not None:
-            _sos = torch.tensor([sos_id])
-            _eos = torch.tensor([eos_id])
-            ys_in = [
-                torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids
-            ]
-            ys_out = [
-                torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids
-            ]
-            ys_in_pad = pad_list(ys_in, eos_id)
-            ys_out_pad = pad_list(ys_out, -1)
-        else:
-            raise ValueError("Invalid input for decoder self attention")
+
+        ys_in = add_sos(token_ids, sos_id=sos_id)
+        ys_in = [torch.tensor(y) for y in ys_in]
+        ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
+
+        ys_out = add_eos(token_ids, eos_id=eos_id)
+        ys_out = [torch.tensor(y) for y in ys_out]
+        ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)

         device = memory.device
         ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
@@ -384,6 +339,8 @@ class Transformer(nn.Module):
             device
         )

+        # TODO: Use eos_id as ignore_id.
+        # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
         tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)

         tgt = self.decoder_embed(ys_in_pad)  # (B, T) -> (B, T, F)
@@ -948,8 +905,8 @@ def decoder_padding_mask(
 ) -> torch.Tensor:
     """Generate a length mask for input.

-    The masked position are filled with bool(True),
-    Unmasked positions are filled with bool(False).
+    The masked position are filled with True,
+    Unmasked positions are filled with False.

     Args:
       ys_pad:
@@ -965,45 +922,16 @@ def decoder_padding_mask(
     return ys_mask


-def get_normal_transcripts(
-    supervision: Supervisions, words: k2.SymbolTable, oov: str = "<UNK>"
-) -> List[List[int]]:
-    """Get normal transcripts (1 input recording has 1 transcript)
-    from lhotse cut format.
-
-    Achieved by concatenating the transcripts corresponding to the
-    same recording.
-
-    Args:
-      supervision:
-        Supervision in lhotse format.
-        See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
-      words:
-        The word symbol table.
-      oov:
-        Out of vocabulary word.
-
-    Returns:
-      List[List[int]]: List of concatenated transcripts, length is batch_size
-    """
-
-    texts = [
-        [token if token in words else oov for token in text.split(" ")]
-        for text in supervision["text"]
-    ]
-    texts_ids = [[words[token] for token in text] for text in texts]
-
-    batch_text = [
-        [] for _ in range(int(supervision["sequence_idx"].max().item()) + 1)
-    ]
-    for sequence_idx, text in zip(supervision["sequence_idx"], texts_ids):
-        batch_text[sequence_idx] = batch_text[sequence_idx] + text
-    return batch_text
-
-
 def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
     """Generate a square mask for the sequence. The masked positions are
     filled with float('-inf'). Unmasked positions are filled with float(0.0).
     The mask can be used for masked self-attention.

     For instance, if sz is 3, it returns::

         tensor([[0., -inf, -inf],
                 [0., 0., -inf],
                 [0., 0., 0]])

     Args:
       sz: mask size
@@ -1020,115 +948,41 @@ def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
     return mask


-def add_sos_eos(
-    ys: List[List[int]],
-    L_inv: k2.Fsa,
-    sos_id: int,
-    eos_id: int,
-    ignore_id: int = -1,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Add <sos> and <eos> labels.
+def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
+    """Prepend sos_id to each utterance.

     Args:
-      ys:
-        Batch of unpadded target sequences (i.e., word IDs)
-      L_inv:
-        Its labels are words, while its aux_labels are tokens.
+      token_ids:
+        A list-of-list of token IDs. Each sublist contains
+        token IDs (e.g., word piece IDs) of an utterance.
       sos_id:
-        index of <sos>
+        The ID of the SOS token.
+
+    Return:
+      Return a new list-of-list, where each sublist starts
+      with SOS ID.
+    """
+    ans = []
+    for utt in token_ids:
+        ans.append([sos_id] + utt)
+    return ans
+
+
+def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
+    """Append eos_id to each utterance.
+
+    Args:
+      token_ids:
+        A list-of-list of token IDs. Each sublist contains
+        token IDs (e.g., word piece IDs) of an utterance.
       eos_id:
-        index of <eos>
-      ignore_id:
-        value for padding
+        The ID of the EOS token.

-    Returns:
-      Return a tuple containing two tensors:
-        - Input of transformer decoder.
-          Padded tensor of dimension (batch_size, max_length).
-        - Output of transformer decoder.
-          Padded tensor of dimension (batch_size, max_length).
+    Return:
+      Return a new list-of-list, where each sublist ends
+      with EOS ID.
     """
-
-    _sos = torch.tensor([sos_id])
-    _eos = torch.tensor([eos_id])
-    ys = get_hierarchical_targets(ys, L_inv)
-    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
-    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
-    return pad_list(ys_in, eos_id), pad_list(ys_out, ignore_id)
-
-
-def pad_list(ys: List[torch.Tensor], pad_value: float) -> torch.Tensor:
-    """Perform padding for the list of tensors.
-
-    Args:
-      ys:
-        List of tensors. len(ys) = batch_size.
-      pad_value:
-        Value for padding.
-
-    Returns:
-      Tensor: Padded tensor (batch_size, max_length, `*`).
-
-    Examples:
-        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
-        >>> x
-        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
-        >>> pad_list(x, 0)
-        tensor([[1., 1., 1., 1.],
-                [1., 1., 0., 0.],
-                [1., 0., 0., 0.]])
-
-    """
-    n_batch = len(ys)
-    max_len = max(x.size(0) for x in ys)
-    pad = ys[0].new_full((n_batch, max_len, *ys[0].size()[1:]), pad_value)
-
-    for i in range(n_batch):
-        pad[i, : ys[i].size(0)] = ys[i]
-
-    return pad
-
-
-def get_hierarchical_targets(
-    ys: List[List[int]], L_inv: Optional[k2.Fsa] = None
-) -> List[torch.Tensor]:
-    """Get hierarchical transcripts (i.e., phone level transcripts) from
-    transcripts (i.e., word level transcripts).
-
-    Args:
-      ys:
-        Word level transcripts. Each sublist is a transcript of an utterance.
-      L_inv:
-        Its labels are words, while its aux_labels are tokens.
-
-    Returns:
-      List[torch.Tensor]:
-        Token level transcripts.
-    """
-
-    if L_inv is None:
-        return [torch.tensor(y) for y in ys]
-
-    device = L_inv.device
-
-    transcripts = k2.create_fsa_vec(
-        [k2.linear_fsa(x, device=device) for x in ys]
-    )
-    transcripts_with_self_loops = k2.add_epsilon_self_loops(transcripts)
-
-    transcripts_lexicon = k2.intersect(
-        L_inv, transcripts_with_self_loops, treat_epsilons_specially=False
-    )
-    # Don't call invert_() above because we want to return phone IDs,
-    # which is the `aux_labels` of transcripts_lexicon
-    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
-    transcripts_lexicon = k2.top_sort(transcripts_lexicon)
-
-    transcripts_lexicon = k2.shortest_path(
-        transcripts_lexicon, use_double_scores=True
-    )
-
-    ys = get_texts(transcripts_lexicon)
-    ys = [torch.tensor(y) for y in ys]
-
-    return ys
+    ans = []
+    for utt in token_ids:
+        ans.append(utt + [eos_id])
+    return ans
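Note (not part of the commit): `generate_square_subsequent_mask` and the padding masks above are the two kinds of mask a standard `torch.nn.TransformerDecoder` consumes. The sketch below shows that wiring with made-up sizes and an inline causal mask with the same layout as the helper; it illustrates the mask semantics only, not the recipe's decoder.

import torch
import torch.nn as nn

d_model, nhead, num_layers = 16, 4, 2
decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

N, T_tgt, T_src = 3, 4, 10
tgt = torch.rand(T_tgt, N, d_model)     # embedded ys_in_pad, time-major
memory = torch.rand(T_src, N, d_model)  # encoder output with shape [T, N, C]

# Causal mask: 0 on and below the diagonal, -inf above, like
# generate_square_subsequent_mask.
tgt_mask = torch.triu(
    torch.full((T_tgt, T_tgt), float("-inf")), diagonal=1
)

# Padding masks: True marks positions to ignore.
tgt_key_padding_mask = torch.zeros(N, T_tgt, dtype=torch.bool)
tgt_key_padding_mask[1, 3:] = True  # second utterance is one token shorter
memory_key_padding_mask = torch.zeros(N, T_src, dtype=torch.bool)

out = decoder(
    tgt,
    memory,
    tgt_mask=tgt_mask,
    tgt_key_padding_mask=tgt_key_padding_mask,
    memory_key_padding_mask=memory_key_padding_mask,
)
assert out.shape == (T_tgt, N, d_model)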