mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-12 11:32:19 +00:00

Reformatted streaming_decode.py with flake8

This commit is contained in:
parent b574e68bf4
commit 9ab3021640
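
The commit title says streaming_decode.py was brought into line with flake8. flake8 only reports style violations rather than rewriting code, so the changes below were applied by hand (or by a separate formatter). A minimal sketch of reproducing the check through flake8's public legacy API; the file path and the 80-column limit are assumptions, not taken from this commit:

    # Run flake8 programmatically and count remaining violations.
    from flake8.api import legacy as flake8

    style_guide = flake8.get_style_guide(max_line_length=80)
    report = style_guide.check_files(["streaming_decode.py"])
    print(f"remaining flake8 violations: {report.total_errors}")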
@@ -22,13 +22,14 @@ Usage:
 
 """
 
-import pdb
 import argparse
 import logging
 import math
+import os
+import pdb
+import subprocess as sp
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
-from tokenizer import Tokenizer
 
 import k2
 import numpy as np
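
Hoisting `import subprocess as sp` into the module-level import block is presumably what motivates the renames in the rest of this diff: at module scope, `sp` now names the subprocess module, so a Tokenizer parameter called `sp` would shadow it. A minimal sketch of the clash (the function body is illustrative, not from the file):

    import subprocess as sp  # module level: "sp" is the subprocess module


    def decode_dataset(sp):
        # The parameter shadows the module alias, so code in this function
        # that expects subprocess via "sp" would see the Tokenizer instead.
        return sp.decode([1, 2, 3])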
@@ -42,6 +43,7 @@ from streaming_beam_search import (
     greedy_search,
     modified_beam_search,
 )
+from tokenizer import Tokenizer
 from torch import Tensor, nn
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_model, get_params
@@ -61,9 +63,6 @@ from icefall.utils import (
     write_error_stats,
 )
 
-import subprocess as sp
-import os
-
 LOG_EPS = math.log(1e-10)
 
 
@@ -518,7 +517,7 @@ def decode_one_chunk(
         decode_streams[i].states = states[i]
         decode_streams[i].done_frames += encoder_out_lens[i]
         # if decode_streams[i].done:
         # finished_streams.append(i)
         finished_streams.append(i)
 
     return finished_streams
@@ -528,7 +527,7 @@ def decode_dataset(
     cuts: CutSet,
     params: AttributeDict,
     model: nn.Module,
-    sp: Tokenizer,
+    tokenizer: Tokenizer,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -540,7 +539,7 @@ def decode_dataset(
         It is returned by :func:`get_params`.
       model:
         The neural model.
-      sp:
+      tokenizer:
         The BPE model.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
@@ -608,7 +607,7 @@ def decode_dataset(
                 (
                     decode_streams[i].id,
                     decode_streams[i].ground_truth.split(),
-                    sp.decode(decode_streams[i].decoding_result()).split(),
+                    tokenizer.decode(decode_streams[i].decoding_result()).split(),
                 )
             )
             del decode_streams[i]
@@ -633,14 +632,13 @@ def decode_dataset(
             print("No finished streams, breaking the loop")
             break
 
-
     for i in sorted(finished_streams, reverse=True):
         try:
             decode_results.append(
                 (
                     decode_streams[i].id,
                     decode_streams[i].ground_truth.split(),
-                    sp.decode(decode_streams[i].decoding_result()).split(),
+                    tokenizer.decode(decode_streams[i].decoding_result()).split(),
                 )
             )
             del decode_streams[i]
@@ -755,12 +753,12 @@ def main():
 
     logging.info(f"Device: {device}")
 
-    sp = Tokenizer.load(params.lang, params.lang_type)
+    sp_token = Tokenizer.load(params.lang, params.lang_type)
 
     # <blk> and <unk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = sp_token.piece_to_id("<blk>")
+    params.unk_id = sp_token.piece_to_id("<unk>")
+    params.vocab_size = sp_token.get_piece_size()
 
     logging.info(params)
 
@@ -870,7 +868,7 @@ def main():
         cuts=test_cut,
         params=params,
         model=model,
-        sp=sp,
+        tokenizer=sp_token,
         decoding_graph=decoding_graph,
     )
     save_results(
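
After the renames, each object keeps a single unambiguous name: `sp` is reserved for subprocess, `sp_token` holds the Tokenizer loaded in main(), and decode_dataset() receives it as `tokenizer`. The resulting call site, restated from the hunk above (the name of the enclosing assignment is an assumption):

    results_dict = decode_dataset(
        cuts=test_cut,
        params=params,
        model=model,
        tokenizer=sp_token,
        decoding_graph=decoding_graph,
    )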