Add more documentation.

This commit is contained in:
Fangjun Kuang 2022-03-07 12:57:11 +08:00
parent 6bcfa6225f
commit d50e7734a6
2 changed files with 144 additions and 53 deletions


@@ -16,63 +16,88 @@
from dataclasses import dataclass
from typing import Iterator, List, Optional
import sentencepiece as spm
import torch
from model import Transducer
# TODO(fangjun): Add more documentation
# The force alignment problem can be formulated as finding
# a path in a rectangular lattice, where the path starts
# from the lower left corner and ends at the upper right
# corner. The horizontal axis of the lattice is `t` (representing
# acoustic frame indexes) and the vertical axis is `u` (representing
# BPE tokens of the transcript).
#
# Beam search is used to find the path with the
# highest log probability.
#
# It assumes that the maximum number of symbols that can be
# emitted per frame is 1. You can use `--modified-transducer-prob`
# from train.py to train a model that satisfies this assumption.
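#
# For example (illustrative, not from this file): with T=4 frames
# and a transcript of U=2 tokens ys=[5, 9], one valid path emits
#
#   frame t:  0  1  2  3
#   output:   5  0  9  0   (0 is the blank)
#
# i.e. the framewise alignment has length T, and removing blanks
# from it recovers ys.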
# AlignItem is a node in the lattice, where
# len(ys) equals `t` and pos_u is the `u` coordinate
# in the lattice.
@dataclass
class AlignItem:
# The log prob of this item
log_prob: float
# It contains the framewise token alignment
ys: List[int]
# It equals the number of non-zero entries in ys
pos_u: int
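# For example, an item with ys=[5, 0, 9] sits at t=3 with pos_u=2
# in the lattice (assuming the blank ID is 0).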
class AlignItemList:
def __init__(self, items: Optional[List[AlignItem]] = None):
"""
Args:
items:
A list of AlignItem
"""
if items is None:
items = []
self.data = items
def __iter__(self) -> Iterator:
return iter(self.data)
def __len__(self) -> int:
"""Return the number of AlignItem in this object."""
return len(self.data)
def __getitem__(self, i: int) -> AlignItem:
"""Return the i-th item in this object."""
return self.data[i]
def append(self, item: AlignItem) -> None:
"""Append an item to the end of this object."""
self.data.append(item)
def get_active_items(self, T: int, U: int) -> "AlignItemList":
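"""Return the items that can still reach the end of the lattice.
An item with t = len(item.ys) frames consumed and pos_u tokens
emitted is kept only if the number of remaining frames T - t is
at least the number of remaining tokens U - pos_u, since at most
one token is emitted per frame. For example, with T=10 and U=4,
an item at t=8 with pos_u=1 still has 3 tokens to emit but only
2 frames left, so it is dropped.
Args:
T:
The number of acoustic frames.
U:
The number of BPE tokens in the transcript.
Returns:
Return an AlignItemList containing the active items.
"""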
ans = []
for item in self:
t = len(item.ys)
if U - item.pos_u > T - t:
continue
ans.append(item)
return AlignItemList(ans)
def get_decoder_input(
self,
ys: List[int],
context_size: int,
blank_id: int,
) -> List[List[int]]:
"""Get input for the decoder for each item in this object.
Args:
ys:
The transcript of the utterance in BPE tokens.
context_size:
Context size of the NN decoder model.
blank_id:
The ID of the blank symbol.
Returns:
Return a list-of-lists of int. `ans[i]` contains the decoder
input for the i-th item in this object and its length
is `context_size`.
"""
ans: List[List[int]] = []
buf = [blank_id] * context_size + ys
for item in self:
@@ -82,6 +107,18 @@ class AlignItemList:
return ans
def topk(self, k: int) -> "AlignItemList":
"""Return the top-k items.
Items are ordered by their log probs in descending order
and the top-k items are returned.
Args:
k:
Size of top-k.
Returns:
Return a new AlignItemList that contains the top-k items
in this object. Caution: it uses a shallow copy.
"""
items = list(self)
items = sorted(items, key=lambda i: i.log_prob, reverse=True)
return AlignItemList(items[:k])
@@ -93,24 +130,28 @@ def force_alignment(
ys: List[int],
beam_size: int = 4,
) -> List[int]:
"""
"""Compute the force alignment of an utterance given its transcript
in BPE tokens and the corresponding acoustic output from the encoder.
Caution:
We assume that the maximum number of symbols per frame is 1.
That is, the model should be trained using a nonzero value
for the option `--modified-transducer-prob` in train.py.
Args:
model:
The transducer model.
encoder_out:
A tensor of shape (N, T, C). Only N==1 is supported at present.
ys:
A list of BPE token IDs. We require that len(ys) <= T.
beam_size:
Size of the beam used in beam search.
Returns:
Return a list of int such that
- len(ans) == T
- After removing blanks from ans, we have ans == ys.
"""
assert encoder_out.ndim == 3, encoder_out.ndim
assert encoder_out.size(0) == 1, encoder_out.size(0)
assert 0 < len(ys) <= encoder_out.size(1), (len(ys), encoder_out.size(1))
@@ -135,7 +176,6 @@ def force_alignment(
# current_encoder_out is of shape (1, 1, encoder_out_dim)
# fmt: on
A = B # shallow copy
B = AlignItemList()
@@ -184,4 +224,39 @@ def force_alignment(
if len(B) > beam_size:
B = B.topk(beam_size)
ans = B.topk(1)[0].ys
assert len(ans) == T
assert list(filter(lambda i: i != 0, ans)) == ys
return ans
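# Illustrative usage (an assumption, not part of this commit);
# `bpe.model`, `text`, `features`, and `feature_lens` are
# hypothetical, and we assume the encoder returns a pair
# (encoder_out, encoder_out_lens):
#
#   sp = spm.SentencePieceProcessor()
#   sp.load("bpe.model")
#   ys = sp.encode(text, out_type=int)
#   encoder_out, _ = model.encoder(features, feature_lens)
#   ali = force_alignment(model, encoder_out, ys, beam_size=4)
#   # len(ali) == encoder_out.size(1); removing blanks yields ys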
def get_word_begin_frame(
ali: List[int], sp: spm.SentencePieceProcessor
) -> List[int]:
"""Get the beginning of each word from the given alignments.
When a word is encoded into BPE tokens, its first token starts
with "▁" (U+2581, which looks like an underscore) and can be
used to identify the beginning of a word.
Args:
ali:
Framewise token alignment. It can be the return value of
:func:`force_alignment`.
sp:
The sentencepiece model.
Returns:
Return a list of int representing the starting frame of each word
in the alignment.
Caution:
You have to take into account the model subsampling factor when
converting the starting frame into time.
"""
underscore = b"\xe2\x96\x81".decode()  # '▁'
ans = []
for i, token in enumerate(ali):
if sp.id_to_piece(token).startswith(underscore):
ans.append(i)
return ans
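# Illustrative follow-up (not part of this commit): with the
# 0.01 s feature frame shift and subsampling factor of 4 used
# when storing the alignments, each alignment frame covers
# 0.04 s, so
#
#   word_begin_frames = get_word_begin_frame(ali, sp)
#   word_begin_times = [f * 0.04 for f in word_begin_frames]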


@@ -137,20 +137,10 @@ def get_parser():
return parser
def compute_alignments(
model: torch.nn.Module,
dl: torch.utils.data.DataLoader,
ali_writer: FeaturesWriter,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
):
@@ -202,10 +192,30 @@ def compute_alignments(
beam_size=params.beam_size,
)
ali_list.append(ali)
assert len(ali_list) == len(cut_list)
for cut, ali in zip(cut_list, ali_list):
cut.token_alignment = ali_writer.store_array(
key=cut.id,
value=np.asarray(ali, dtype=np.int32),
# frame shift is 0.01s, subsampling_factor is 4
frame_shift=0.04,
temporal_dim=0,
start=0,
)
cuts += cut_list
num_cuts += len(cut_list)
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return CutSet.from_cuts(cuts)
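# Illustrative read-back (an assumption, not part of this commit):
# lhotse exposes arrays saved with store_array() as custom cut
# fields, so the stored alignment can later be loaded like this:
#
#   cuts = CutSet.from_file(out_manifest_filename)
#   for cut in cuts:
#       token_ali = cut.load_token_alignment()  # int32 array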
@torch.no_grad()
@@ -242,13 +252,11 @@ def main():
out_dir = Path(params.out_dir)
out_dir.mkdir(exist_ok=True)
out_ali_filename = out_dir / f"token_ali_{params.dataset}.h5"
out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"
for f in (
out_ali_filename,
out_manifest_filename,
):
if f.exists():
@@ -305,18 +313,26 @@ def main():
logging.info(f"Processing {params.dataset}")
with NumpyHdf5Writer(out_ali_filename) as ali_writer:
cut_set = compute_alignments(
model=model,
dl=dl,
ali_writer=ali_writer,
params=params,
sp=sp,
)
cut_set.to_file(out_manifest_filename)
logging.info(
f"For dataset {params.dataset}, its framewise token alignments are "
f"saved to {out_ali_filename} and the cut manifest "
f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
)
# torch.set_num_threads(1)
# torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()