Fix more typos.

This commit is contained in:
Fangjun Kuang 2022-03-07 16:29:13 +08:00
parent fb63ed627d
commit 2aca0d536c
2 changed files with 8 additions and 9 deletions

View File

@ -135,7 +135,7 @@ def force_alignment(
Caution: Caution:
We assume that the maximum number of symbols per frame is 1. We assume that the maximum number of symbols per frame is 1.
That is, the model should be training using a nonzero value That is, the model should be trained using a nonzero value
for the option `--modified-transducer-prob` in train.py. for the option `--modified-transducer-prob` in train.py.
Args: Args:
@ -163,6 +163,7 @@ def force_alignment(
T = encoder_out.size(1) T = encoder_out.size(1)
U = len(ys) U = len(ys)
assert 0 < U <= T
encoder_out_len = torch.tensor([1]) encoder_out_len = torch.tensor([1])
decoder_out_len = encoder_out_len decoder_out_len = encoder_out_len
@ -204,7 +205,7 @@ def force_alignment(
for i, item in enumerate(A): for i, item in enumerate(A):
if (T - 1 - t) >= (U - item.pos_u): if (T - 1 - t) >= (U - item.pos_u):
# horizontal transition # horizontal transition (left -> right)
new_item = AlignItem( new_item = AlignItem(
log_prob=item.log_prob + log_probs[i][blank_id], log_prob=item.log_prob + log_probs[i][blank_id],
ys=item.ys + [blank_id], ys=item.ys + [blank_id],
@ -213,7 +214,7 @@ def force_alignment(
B.append(new_item) B.append(new_item)
if item.pos_u < U: if item.pos_u < U:
# diagonal transition # diagonal transition (lower left -> upper right)
u = ys[item.pos_u] u = ys[item.pos_u]
new_item = AlignItem( new_item = AlignItem(
log_prob=item.log_prob + log_probs[i][u], log_prob=item.log_prob + log_probs[i][u],
@ -221,13 +222,14 @@ def force_alignment(
pos_u=item.pos_u + 1, pos_u=item.pos_u + 1,
) )
B.append(new_item) B.append(new_item)
if len(B) > beam_size: if len(B) > beam_size:
B = B.topk(beam_size) B = B.topk(beam_size)
ans = B.topk(1)[0].ys ans = B.topk(1)[0].ys
assert len(ans) == T assert len(ans) == T
assert list(filter(lambda i: i != 0, ans)) == ys assert list(filter(lambda i: i != blank_id, ans)) == ys
return ans return ans
@ -235,7 +237,7 @@ def force_alignment(
def get_word_starting_frame( def get_word_starting_frame(
ali: List[int], sp: spm.SentencePieceProcessor ali: List[int], sp: spm.SentencePieceProcessor
) -> List[int]: ) -> List[int]:
"""Get the starting frame of each word from the given alignments. """Get the starting frame of each word from the given token alignments.
When a word is encoded into BPE tokens, the first token starts When a word is encoded into BPE tokens, the first token starts
with underscore "_", which can be used to identify the starting frame with underscore "_", which can be used to identify the starting frame

View File

@ -85,7 +85,7 @@ def get_parser():
type=str, type=str,
required=True, required=True,
help="""Output directory. help="""Output directory.
It contains 3 generated files: It contains 2 generated files:
- token_ali_xxx.h5 - token_ali_xxx.h5
- cuts_xxx.json.gz - cuts_xxx.json.gz
@ -322,8 +322,5 @@ def main():
done_file.touch() done_file.touch()
# torch.set_num_threads(1)
# torch.set_num_interop_threads(1)
if __name__ == "__main__": if __name__ == "__main__":
main() main()