Fix more typos.

This commit is contained in:
Fangjun Kuang 2022-03-07 16:29:13 +08:00
parent fb63ed627d
commit 2aca0d536c
2 changed files with 8 additions and 9 deletions

View File

@ -135,7 +135,7 @@ def force_alignment(
Caution:
We assume that the maximum number of symbols per frame is 1.
That is, the model should be training using a nonzero value
That is, the model should be trained using a nonzero value
for the option `--modified-transducer-prob` in train.py.
Args:
@ -163,6 +163,7 @@ def force_alignment(
T = encoder_out.size(1)
U = len(ys)
assert 0 < U <= T
encoder_out_len = torch.tensor([1])
decoder_out_len = encoder_out_len
@ -204,7 +205,7 @@ def force_alignment(
for i, item in enumerate(A):
if (T - 1 - t) >= (U - item.pos_u):
# horizontal transition
# horizontal transition (left -> right)
new_item = AlignItem(
log_prob=item.log_prob + log_probs[i][blank_id],
ys=item.ys + [blank_id],
@ -213,7 +214,7 @@ def force_alignment(
B.append(new_item)
if item.pos_u < U:
# diagonal transition
# diagonal transition (lower left -> upper right)
u = ys[item.pos_u]
new_item = AlignItem(
log_prob=item.log_prob + log_probs[i][u],
@ -221,13 +222,14 @@ def force_alignment(
pos_u=item.pos_u + 1,
)
B.append(new_item)
if len(B) > beam_size:
B = B.topk(beam_size)
ans = B.topk(1)[0].ys
assert len(ans) == T
assert list(filter(lambda i: i != 0, ans)) == ys
assert list(filter(lambda i: i != blank_id, ans)) == ys
return ans
@ -235,7 +237,7 @@ def force_alignment(
def get_word_starting_frame(
ali: List[int], sp: spm.SentencePieceProcessor
) -> List[int]:
"""Get the starting frame of each word from the given alignments.
"""Get the starting frame of each word from the given token alignments.
When a word is encoded into BPE tokens, the first token starts
with underscore "_", which can be used to identify the starting frame

View File

@ -85,7 +85,7 @@ def get_parser():
type=str,
required=True,
help="""Output directory.
It contains 3 generated files:
It contains 2 generated files:
- token_ali_xxx.h5
- cuts_xxx.json.gz
@ -322,8 +322,5 @@ def main():
done_file.touch()
# torch.set_num_threads(1)
# torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()