Add more documentation.

This commit is contained in:
Fangjun Kuang 2022-03-07 12:57:11 +08:00
parent 6bcfa6225f
commit d50e7734a6
2 changed files with 144 additions and 53 deletions

View File

@ -16,63 +16,88 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional from typing import Iterator, List, Optional
import sentencepiece as spm
import torch import torch
from model import Transducer from model import Transducer
# TODO(fangjun): Add more documentation # The force alignment problem can be formulated as finding
# The force alignment problem can be formulated as find
# a path in a rectangular lattice, where the path starts # a path in a rectangular lattice, where the path starts
# from the lower left corner and ends at the upper right # from the lower left corner and ends at the upper right
# corner. The horizontal axis of the lattice is `t` # corner. The horizontal axis of the lattice is `t` (representing
# and the vertical axis is `u`. # acoustic frame indexes) and the vertical axis is `u` (representing
# BPE tokens of the transcript).
# #
# Beam search is used to find the path with the
# highest log probability.
#
# It assumes that the maximum number of symbols that can be
# emitted per frame is 1. You can use `--modified-transducer-prob`
# from train.py to train a model that satisfies this assumption.
# AlignItem is a node in the lattice, where its # AlignItem is a node in the lattice, where its
# len(ys) equals to `t` and pos_u is the u coordinate # len(ys) equals to `t` and pos_u is the u coordinate
# in the lattice. # in the lattice.
@dataclass @dataclass
class AlignItem: class AlignItem:
# log prob of this item
log_prob: float log_prob: float
# It contains framewise token alignment
ys: List[int] ys: List[int]
# It equals the number of non-zero entries in ys
pos_u: int pos_u: int
class AlignItemList: class AlignItemList:
def __init__(self, items: Optional[List[AlignItem]] = None): def __init__(self, items: Optional[List[AlignItem]] = None):
"""
Args:
items:
A list of AlignItem
"""
if items is None: if items is None:
items = [] items = []
self.data = items self.data = items
def __iter__(self): def __iter__(self) -> Iterator:
return iter(self.data) return iter(self.data)
def __len__(self): def __len__(self) -> int:
"""Return the number of AlignItem in this object."""
return len(self.data) return len(self.data)
def __getitem__(self, i: int) -> AlignItem: def __getitem__(self, i: int) -> AlignItem:
"""Return the i-th item in this object."""
return self.data[i] return self.data[i]
def append(self, item: AlignItem) -> None: def append(self, item: AlignItem) -> None:
"""Append an item to the end of this object."""
self.data.append(item) self.data.append(item)
def get_active_items(self, T: int, U: int) -> "AlignItemList":
ans = []
for item in self:
t = len(item.ys)
if U - item.pos_u > T - t:
continue
ans.append(item)
return AlignItemList(ans)
def get_decoder_input( def get_decoder_input(
self, self,
ys: List[int], ys: List[int],
context_size: int, context_size: int,
blank_id: int, blank_id: int,
) -> List[List[int]]: ) -> List[List[int]]:
"""Get input for the decoder for each item in this object.
Args:
ys:
The transcript of the utterance in BPE tokens.
context_size:
Context size of the NN decoder model.
blank_id:
The ID of the blank symbol.
Returns:
Return a list of lists of int. `ans[i]` contains the decoder
input for the i-th item in this object and its length
is `context_size`.
"""
ans: List[List[int]] = [] ans: List[List[int]] = []
buf = [blank_id] * context_size + ys buf = [blank_id] * context_size + ys
for item in self: for item in self:
@ -82,6 +107,18 @@ class AlignItemList:
return ans return ans
def topk(self, k: int) -> "AlignItemList": def topk(self, k: int) -> "AlignItemList":
"""Return the top-k items.
Items are ordered by their log probs in descending order
and the top-k items are returned.
Args:
k:
Size of top-k.
Returns:
Return a new AlignItemList that contains the top-k items
in this object. Caution: It uses shallow copy.
"""
items = list(self) items = list(self)
items = sorted(items, key=lambda i: i.log_prob, reverse=True) items = sorted(items, key=lambda i: i.log_prob, reverse=True)
return AlignItemList(items[:k]) return AlignItemList(items[:k])
@ -93,24 +130,28 @@ def force_alignment(
ys: List[int], ys: List[int],
beam_size: int = 4, beam_size: int = 4,
) -> List[int]: ) -> List[int]:
""" """Compute the force alignment of an utterance given its transcript
in BPE tokens and the corresponding acoustic output from the encoder.
Caution:
We assume that the maximum number of symbols per frame is 1.
That is, the model should be trained using a nonzero value
for the option `--modified-transducer-prob` in train.py.
Args: Args:
model: model:
The transducer model. The transducer model.
encoder_out: encoder_out:
A tensor of shape (N, T, C). Support only for N==1 now. A tensor of shape (N, T, C). Support only for N==1 at present.
ys: ys:
A list of token IDs. We require that len(ys) <= T. A list of BPE token IDs. We require that len(ys) <= T.
beam: beam_size:
Size of the beam used in beam search. Size of the beam used in beam search.
Returns: Returns:
Return a list of int such that Return a list of int such that
- len(ans) == T - len(ans) == T
- After removing blanks from ans, we have ans == ys. - After removing blanks from ans, we have ans == ys.
""" """
import pdb
pdb.set_trace()
assert encoder_out.ndim == 3, encoder_out.ndim assert encoder_out.ndim == 3, encoder_out.ndim
assert encoder_out.size(0) == 1, encoder_out.size(0) assert encoder_out.size(0) == 1, encoder_out.size(0)
assert 0 < len(ys) <= encoder_out.size(1), (len(ys), encoder_out.size(1)) assert 0 < len(ys) <= encoder_out.size(1), (len(ys), encoder_out.size(1))
@ -135,7 +176,6 @@ def force_alignment(
# current_encoder_out is of shape (1, 1, encoder_out_dim) # current_encoder_out is of shape (1, 1, encoder_out_dim)
# fmt: on # fmt: on
# A = B.get_active_items()
A = B # shallow copy A = B # shallow copy
B = AlignItemList() B = AlignItemList()
@ -184,4 +224,39 @@ def force_alignment(
if len(B) > beam_size: if len(B) > beam_size:
B = B.topk(beam_size) B = B.topk(beam_size)
return B.topk(1)[0].ys ans = B.topk(1)[0].ys
assert len(ans) == T
assert list(filter(lambda i: i != 0, ans)) == ys
return ans
def get_word_begin_frame(
ali: List[int], sp: spm.SentencePieceProcessor
) -> List[int]:
"""Get the beginning of each word from the given alignments.
When a word is encoded into BPE tokens, the first token starts
with the SentencePiece word-boundary marker "▁" (U+2581, which looks
like an underscore but is not the ASCII "_"); it can be used to
identify the beginning of a word.
Args:
ali:
Framewise token alignment. It can be the return value of
:func:`force_alignment`.
sp:
The sentencepiece model.
Returns:
Return a list of int representing the starting frame of each word
in the alignment.
Caution:
You have to take into account the model subsampling factor when
converting the starting frame into time.
"""
underscore = b"\xe2\x96\x81".decode() # '_'
ans = []
for i in range(len(ali)):
if sp.id_to_piece(ali[i]).startswith(underscore):
ans.append(i)
return ans

View File

@ -137,20 +137,10 @@ def get_parser():
return parser return parser
def get_word_begin_time(ali: List[int], sp: spm.SentencePieceProcessor):
underscore = b"\xe2\x96\x81".decode() # '_'
ans = []
for i in range(len(ali)):
print(sp.id_to_piece(ali[i]))
if sp.id_to_piece(ali[i]).startswith(underscore):
print("yes")
ans.append(i * 0.04)
return ans
def compute_alignments( def compute_alignments(
model: torch.nn.Module, model: torch.nn.Module,
dl: torch.utils.data, dl: torch.utils.data,
ali_writer: FeaturesWriter,
params: AttributeDict, params: AttributeDict,
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
): ):
@ -202,10 +192,30 @@ def compute_alignments(
beam_size=params.beam_size, beam_size=params.beam_size,
) )
ali_list.append(ali) ali_list.append(ali)
word_begin_time_list.append(get_word_begin_time(ali, sp)) assert len(ali_list) == len(cut_list)
import pdb
pdb.set_trace() for cut, ali in zip(cut_list, ali_list):
cut.token_alignment = ali_writer.store_array(
key=cut.id,
value=np.asarray(ali, dtype=np.int32),
# frame shift is 0.01s, subsampling_factor is 4
frame_shift=0.04,
temporal_dim=0,
start=0,
)
cuts += cut_list
num_cuts += len(cut_list)
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return CutSet.from_cuts(cuts)
@torch.no_grad() @torch.no_grad()
@ -242,13 +252,11 @@ def main():
out_dir = Path(params.out_dir) out_dir = Path(params.out_dir)
out_dir.mkdir(exist_ok=True) out_dir.mkdir(exist_ok=True)
out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" out_ali_filename = out_dir / f"token_ali_{params.dataset}.h5"
out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5"
out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz" out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"
for f in ( for f in (
out_labels_ali_filename, out_ali_filename,
out_aux_labels_ali_filename,
out_manifest_filename, out_manifest_filename,
): ):
if f.exists(): if f.exists():
@ -305,18 +313,26 @@ def main():
logging.info(f"Processing {params.dataset}") logging.info(f"Processing {params.dataset}")
with NumpyHdf5Writer(out_ali_filename) as ali_writer:
cut_set = compute_alignments( cut_set = compute_alignments(
model=model, model=model,
dl=dl, dl=dl,
# labels_writer=labels_writer, ali_writer=ali_writer,
# aux_labels_writer=aux_labels_writer,
params=params, params=params,
sp=sp, sp=sp,
) )
cut_set.to_file(out_manifest_filename)
logging.info(
f"For dataset {params.dataset}, its framewise token alignments are "
f"saved to {out_ali_filename} and the cut manifest "
f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
)
# torch.set_num_interop_threads(1)
# torch.set_num_threads(1) # torch.set_num_threads(1)
# torch.set_num_interop_threads(1)
if __name__ == "__main__": if __name__ == "__main__":
main() main()