Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-06 15:44:17 +00:00)

Commit 2aca0d536c: "Fix more typos."
Parent: fb63ed627d
@@ -135,7 +135,7 @@ def force_alignment(
     Caution:
       We assume that the maximum number of sybmols per frame is 1.
-      That is, the model should be training using a nonzero value
+      That is, the model should be trained using a nonzero value
       for the option `--modified-transducer-prob` in train.py.
 
     Args:
@@ -163,6 +163,7 @@ def force_alignment(
     T = encoder_out.size(1)
     U = len(ys)
+    assert 0 < U <= T
 
     encoder_out_len = torch.tensor([1])
     decoder_out_len = encoder_out_len
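The added `assert 0 < U <= T` follows directly from the caution above: with at most one symbol emitted per frame, T frames can account for at most T non-blank tokens, so the reference length U can never exceed T. The hunks that follow build and prune `AlignItem` objects, but the diff only shows the fields it touches (`log_prob`, `ys`, `pos_u`); the sketch below is inferred from those field names and is not the repository's definition:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class AlignItem:
    # Accumulated log-probability of this partial alignment
    # (assumed from how `log_prob` is summed in the diff).
    log_prob: float

    # Symbols emitted so far, blanks included; one entry per processed frame.
    ys: List[int]

    # How many tokens of the reference sequence have been consumed so far.
    pos_u: int
```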
@@ -204,7 +205,7 @@ def force_alignment(
         for i, item in enumerate(A):
             if (T - 1 - t) >= (U - item.pos_u):
-                # horizontal transition
+                # horizontal transition (left -> right)
                 new_item = AlignItem(
                     log_prob=item.log_prob + log_probs[i][blank_id],
                     ys=item.ys + [blank_id],
@@ -213,7 +214,7 @@ def force_alignment(
                 B.append(new_item)
 
             if item.pos_u < U:
-                # diagonal transition
+                # diagonal transition (lower left -> upper right)
                 u = ys[item.pos_u]
                 new_item = AlignItem(
                     log_prob=item.log_prob + log_probs[i][u],
@@ -221,13 +222,14 @@ def force_alignment(
                     pos_u=item.pos_u + 1,
                 )
                 B.append(new_item)
 
         if len(B) > beam_size:
             B = B.topk(beam_size)
 
     ans = B.topk(1)[0].ys
 
     assert len(ans) == T
-    assert list(filter(lambda i: i != 0, ans)) == ys
+    assert list(filter(lambda i: i != blank_id, ans)) == ys
 
     return ans
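A horizontal transition stays at the same token position and emits a blank, while a diagonal transition consumes the next reference token, so every finished path has exactly T entries and removing the blanks must give back `ys`; that is exactly the invariant the two asserts check, with the second one tightened to compare against `blank_id` instead of the literal 0. A tiny self-contained illustration of the invariant (all values below are made up for the example, and `blank_id = 0` is an assumption):

```python
blank_id = 0          # assumed blank id for this toy example
ys = [7, 3, 9]        # hypothetical reference tokens, U = 3
T = 5                 # hypothetical number of frames

# One possible frame-level alignment: two horizontal (blank) steps
# and three diagonal (token) steps.
ali = [blank_id, 7, blank_id, 3, 9]

assert len(ali) == T
# Stripping blanks recovers the reference, mirroring the tightened assert.
assert [i for i in ali if i != blank_id] == ys
```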
@@ -235,7 +237,7 @@ def force_alignment(
 def get_word_starting_frame(
     ali: List[int], sp: spm.SentencePieceProcessor
 ) -> List[int]:
-    """Get the starting frame of each word from the given alignments.
+    """Get the starting frame of each word from the given token alignments.
 
     When a word is encoded into BPE tokens, the first token starts
     with underscore "_", which can be used to identify the starting frame
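Based only on the docstring above, the idea is: a word's first BPE piece carries the word-boundary prefix, so any non-blank frame whose piece starts with that prefix marks where a word begins. A hedged sketch of that logic follows; the helper name, the `blank_id` default, and the use of the "▁" marker (written as underscore "_" in the docstring) are assumptions, not the repository's implementation. With a real model one would pass `sp.id_to_piece`, the standard SentencePiece accessor:

```python
from typing import Callable, List

WORD_START = "▁"  # SentencePiece word-boundary marker ("_" in the docstring)


def word_starting_frames(
    ali: List[int], id_to_piece: Callable[[int], str], blank_id: int = 0
) -> List[int]:
    """Return the frame index at which each word starts."""
    starts = []
    for frame, tok in enumerate(ali):
        if tok == blank_id:
            continue
        if id_to_piece(tok).startswith(WORD_START):
            starts.append(frame)
    return starts


# Toy vocabulary standing in for a real SentencePiece model.
pieces = {1: "▁HE", 2: "LLO", 3: "▁WORLD"}
ali = [0, 1, 2, 0, 3, 0]  # hypothetical frame-level token alignment
print(word_starting_frames(ali, pieces.__getitem__))  # [1, 4]
```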
|
@@ -85,7 +85,7 @@ def get_parser():
         type=str,
         required=True,
         help="""Output directory.
-        It contains 3 generated files:
+        It contains 2 generated files:
 
         - token_ali_xxx.h5
         - cuts_xxx.json.gz
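For orientation, the two generated files named in the help text can be inspected from Python. The paths below just reuse the placeholders from the help string (the `xxx` stays a placeholder), and the assumption that the `.h5` archive can be browsed with plain `h5py` while the cuts manifest loads with `lhotse.load_manifest` is mine, not something stated in the diff:

```python
import h5py
from lhotse import load_manifest

ali_path = "token_ali_xxx.h5"   # placeholder name from the help text
cuts_path = "cuts_xxx.json.gz"  # placeholder name from the help text

# List a few datasets stored in the alignment archive (layout is assumed).
with h5py.File(ali_path, "r") as f:
    print(list(f.keys())[:10])

# Load the corresponding cut manifest and report its size.
cuts = load_manifest(cuts_path)
print(len(cuts))
```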
@@ -322,8 +322,5 @@ def main():
     done_file.touch()
 
 
-# torch.set_num_threads(1)
-# torch.set_num_interop_threads(1)
-
 if __name__ == "__main__":
     main()
|