mirror of https://github.com/k2-fsa/icefall.git
Add more documentation.

commit d50e7734a6
parent 6bcfa6225f
@@ -16,63 +16,88 @@
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Iterator, List, Optional

 import sentencepiece as spm
 import torch
 from model import Transducer


 # TODO(fangjun): Add more documentation

-# The force alignment problem can be formulated as find
+# The force alignment problem can be formulated as finding
 # a path in a rectangular lattice, where the path starts
 # from the lower left corner and ends at the upper right
-# corner. The horizontal axis of the lattice is `t`
-# and the vertical axis is `u`.
+# corner. The horizontal axis of the lattice is `t` (representing
+# acoustic frame indexes) and the vertical axis is `u` (representing
+# BPE tokens of the transcript).
 #
 # Beam search is used to find the path with the
 # highest log probability.
 #
 # It assumes that the maximum number of symbols that can be
 # emitted per frame is 1. You can use `--modified-transducer-prob`
 # from train.py to train a model that satisfies this assumption.


 # AlignItem is a node in the lattice, where
 # its len(ys) equals `t` and pos_u is the `u` coordinate
 # in the lattice.
 @dataclass
 class AlignItem:
     # log prob of this item
     log_prob: float

     # It contains the framewise token alignment
     ys: List[int]

     # It equals the number of non-zero entries in ys
     pos_u: int

 class AlignItemList:
     def __init__(self, items: Optional[List[AlignItem]] = None):
         """
         Args:
           items:
             A list of AlignItem.
         """
         if items is None:
             items = []
         self.data = items

-    def __iter__(self):
+    def __iter__(self) -> Iterator:
         return iter(self.data)

-    def __len__(self):
+    def __len__(self) -> int:
+        """Return the number of AlignItem in this object."""
         return len(self.data)

     def __getitem__(self, i: int) -> AlignItem:
         """Return the i-th item in this object."""
         return self.data[i]

     def append(self, item: AlignItem) -> None:
         """Append an item to the end of this object."""
         self.data.append(item)

     def get_active_items(self, T: int, U: int) -> "AlignItemList":
         ans = []
         for item in self:
             t = len(item.ys)
             if U - item.pos_u > T - t:
                 continue
             ans.append(item)

         return AlignItemList(ans)

     def get_decoder_input(
         self,
         ys: List[int],
         context_size: int,
         blank_id: int,
     ) -> List[List[int]]:
         """Get input for the decoder for each item in this object.

         Args:
           ys:
             The transcript of the utterance in BPE tokens.
           context_size:
             Context size of the NN decoder model.
           blank_id:
             The ID of the blank symbol.
         Returns:
           Return a list of lists of int. `ans[i]` contains the decoder
           input for the i-th item in this object and its length
           is `context_size`.
         """
         ans: List[List[int]] = []
         buf = [blank_id] * context_size + ys
         for item in self:
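The loop body of `get_decoder_input` is elided by this hunk. Below is a minimal sketch of the slicing its docstring describes, assuming each item's decoder input is read from the blank-padded `buf` at offset `pos_u`; the helper is illustrative, not the verbatim implementation:

```python
# Hedged sketch: how a decoder input of length `context_size` could be
# sliced for each item, assuming it is read from the blank-padded
# transcript at offset `pos_u` (the elided loop body is not shown above).
from typing import List


def sketch_get_decoder_input(
    ys: List[int], pos_u_list: List[int], context_size: int, blank_id: int
) -> List[List[int]]:
    # Left-pad the transcript with `context_size` blanks so that an item
    # that has emitted `pos_u` tokens sees the last `context_size` symbols
    # of its partial transcript.
    buf = [blank_id] * context_size + ys
    return [buf[pos_u : pos_u + context_size] for pos_u in pos_u_list]


# With context_size=2, blank_id=0, and transcript ys=[5, 9, 3]:
# an item at pos_u=0 gets [0, 0]; an item at pos_u=2 gets [5, 9].
assert sketch_get_decoder_input([5, 9, 3], [0, 2], 2, 0) == [[0, 0], [5, 9]]
```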
@@ -82,6 +107,18 @@ class AlignItemList:
         return ans

     def topk(self, k: int) -> "AlignItemList":
+        """Return the top-k items.
+
+        Items are ordered by their log probs in descending order
+        and the top-k items are returned.
+
+        Args:
+          k:
+            Size of top-k.
+        Returns:
+          Return a new AlignItemList that contains the top-k items
+          in this object. Caution: it uses a shallow copy.
+        """
         items = list(self)
         items = sorted(items, key=lambda i: i.log_prob, reverse=True)
         return AlignItemList(items[:k])
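A small usage sketch of `topk` with the classes from this file; the items and log probs are made up:

```python
# Hypothetical items: `b` has the higher log prob, so topk(1) returns it.
a = AlignItem(log_prob=-1.0, ys=[0, 5], pos_u=1)
b = AlignItem(log_prob=-0.5, ys=[5, 9], pos_u=2)

best = AlignItemList([a, b]).topk(1)
assert len(best) == 1
assert best[0] is b  # shallow copy: the same AlignItem object is returned
```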
@@ -93,24 +130,28 @@ def force_alignment(
     ys: List[int],
     beam_size: int = 4,
 ) -> List[int]:
-    """
+    """Compute the force alignment of an utterance given its transcript
+    in BPE tokens and the corresponding acoustic output from the encoder.
+
+    Caution:
+      We assume that the maximum number of symbols per frame is 1.
+      That is, the model should be trained using a nonzero value
+      for the option `--modified-transducer-prob` in train.py.

     Args:
       model:
         The transducer model.
       encoder_out:
-        A tensor of shape (N, T, C). Support only for N==1 now.
+        A tensor of shape (N, T, C). Only N==1 is supported at present.
       ys:
-        A list of token IDs. We require that len(ys) <= T.
-      beam:
+        A list of BPE token IDs. We require that len(ys) <= T.
+      beam_size:
         Size of the beam used in beam search.
     Returns:
       Return a list of int such that:
         - len(ans) == T
         - After removing blanks from ans, we have ans == ys.
     """
-    import pdb
-
-    pdb.set_trace()
     assert encoder_out.ndim == 3, encoder_out.ndim
     assert encoder_out.size(0) == 1, encoder_out.size(0)
     assert 0 < len(ys) <= encoder_out.size(1), (len(ys), encoder_out.size(1))
@@ -135,7 +176,6 @@ def force_alignment(
         # current_encoder_out is of shape (1, 1, encoder_out_dim)
         # fmt: on

-        # A = B.get_active_items()
         A = B  # shallow copy
         B = AlignItemList()

@@ -184,4 +224,39 @@ def force_alignment(
         if len(B) > beam_size:
             B = B.topk(beam_size)

-    return B.topk(1)[0].ys
+    ans = B.topk(1)[0].ys
+
+    assert len(ans) == T
+    assert list(filter(lambda i: i != 0, ans)) == ys
+
+    return ans
+
+
+def get_word_begin_frame(
+    ali: List[int], sp: spm.SentencePieceProcessor
+) -> List[int]:
+    """Get the beginning frame of each word from the given alignment.
+
+    When a word is encoded into BPE tokens, its first token starts
+    with the underscore symbol "▁" (U+2581), which can be used to
+    identify the beginning of a word.
+
+    Args:
+      ali:
+        Framewise token alignment. It can be the return value of
+        :func:`force_alignment`.
+      sp:
+        The sentencepiece model.
+    Returns:
+      Return a list of int representing the starting frame of each word
+      in the alignment.
+      Caution:
+        You have to take into account the model subsampling factor when
+        converting the starting frame into time.
+    """
+    underscore = b"\xe2\x96\x81".decode()  # '▁'
+    ans = []
+    for i in range(len(ali)):
+        if sp.id_to_piece(ali[i]).startswith(underscore):
+            ans.append(i)
+    return ans
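A toy walk-through of `get_word_begin_frame`, with a made-up id-to-piece table standing in for the sentencepiece model:

```python
# Hypothetical id -> piece table standing in for sp.id_to_piece().
pieces = {0: "<blk>", 5: "▁HELLO", 9: "▁WOR", 3: "LD"}

# Framewise alignment, e.g. from force_alignment(): "▁HELLO" at frame 1,
# "▁WOR" at frame 3, and the word-internal piece "LD" at frame 4.
ali = [0, 5, 0, 9, 3]

word_begin_frames = [
    i for i, t in enumerate(ali) if pieces[t].startswith("▁")
]
assert word_begin_frames == [1, 3]

# Per the docstring's caution: with a 0.01 s feature frame shift and a
# model subsampling factor of 4, each alignment frame covers 0.04 s.
word_begin_times = [i * 0.04 for i in word_begin_frames]  # [0.04, 0.12]
```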
@@ -137,20 +137,10 @@ def get_parser():
     return parser


-def get_word_begin_time(ali: List[int], sp: spm.SentencePieceProcessor):
-    underscore = b"\xe2\x96\x81".decode()  # '_'
-    ans = []
-    for i in range(len(ali)):
-        print(sp.id_to_piece(ali[i]))
-        if sp.id_to_piece(ali[i]).startswith(underscore):
-            print("yes")
-            ans.append(i * 0.04)
-    return ans
-
-
 def compute_alignments(
     model: torch.nn.Module,
     dl: torch.utils.data.DataLoader,
     ali_writer: FeaturesWriter,
     params: AttributeDict,
     sp: spm.SentencePieceProcessor,
 ):
@@ -202,10 +192,30 @@ def compute_alignments(
                 beam_size=params.beam_size,
             )
             ali_list.append(ali)
-            word_begin_time_list.append(get_word_begin_time(ali, sp))
-        import pdb
-
-        pdb.set_trace()
+        assert len(ali_list) == len(cut_list)
+
+        for cut, ali in zip(cut_list, ali_list):
+            cut.token_alignment = ali_writer.store_array(
+                key=cut.id,
+                value=np.asarray(ali, dtype=np.int32),
+                # frame shift is 0.01s, subsampling_factor is 4
+                frame_shift=0.04,
+                temporal_dim=0,
+                start=0,
+            )

         cuts += cut_list

         num_cuts += len(cut_list)

         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"

             logging.info(
                 f"batch {batch_str}, cuts processed until now is {num_cuts}"
             )

     return CutSet.from_cuts(cuts)
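The `frame_shift=0.04` passed to `store_array` above follows from the comment next to it; a quick sanity check of the arithmetic:

```python
# frame shift is 0.01s and the model subsamples the time axis by 4,
# so one entry of the stored alignment spans 0.01 * 4 = 0.04 seconds.
feature_frame_shift = 0.01  # seconds per acoustic feature frame
subsampling_factor = 4      # encoder time-axis subsampling
alignment_frame_shift = feature_frame_shift * subsampling_factor
assert abs(alignment_frame_shift - 0.04) < 1e-9
```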


 @torch.no_grad()
@@ -242,13 +252,11 @@ def main():
     out_dir = Path(params.out_dir)
     out_dir.mkdir(exist_ok=True)

-    out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5"
-    out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5"
+    out_ali_filename = out_dir / f"token_ali_{params.dataset}.h5"
     out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"

     for f in (
-        out_labels_ali_filename,
-        out_aux_labels_ali_filename,
+        out_ali_filename,
         out_manifest_filename,
     ):
         if f.exists():
@@ -305,18 +313,26 @@ def main():

     logging.info(f"Processing {params.dataset}")

-    cut_set = compute_alignments(
-        model=model,
-        dl=dl,
-        # labels_writer=labels_writer,
-        # aux_labels_writer=aux_labels_writer,
-        params=params,
-        sp=sp,
-    )
+    with NumpyHdf5Writer(out_ali_filename) as ali_writer:
+        cut_set = compute_alignments(
+            model=model,
+            dl=dl,
+            ali_writer=ali_writer,
+            params=params,
+            sp=sp,
+        )

     cut_set.to_file(out_manifest_filename)

+    logging.info(
+        f"For dataset {params.dataset}, its framewise token alignments are "
+        f"saved to {out_ali_filename} and the cut manifest "
+        f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
+    )


 # torch.set_num_threads(1)
 # torch.set_num_interop_threads(1)

 if __name__ == "__main__":
     main()