Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-10 10:32:17 +00:00)

Reformatted streaming_decode.py with flake8

Commit 9ab3021640, parent b574e68bf4
@@ -22,13 +22,14 @@ Usage:
 
 """
 
-import pdb
 import argparse
 import logging
 import math
+import os
+import pdb
+import subprocess as sp
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
-from tokenizer import Tokenizer
 
 import k2
 import numpy as np
@@ -42,6 +43,7 @@ from streaming_beam_search import (
     greedy_search,
     modified_beam_search,
 )
+from tokenizer import Tokenizer
 from torch import Tensor, nn
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_model, get_params
@@ -61,9 +63,6 @@ from icefall.utils import (
     write_error_stats,
 )
 
-import subprocess as sp
-import os
-
 LOG_EPS = math.log(1e-10)
 
 
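Taken together, the three import hunks above fold the stray imports (pdb at the very top, subprocess and os below the icefall imports, tokenizer mixed into the stdlib block) into the conventional flake8/isort layout: one alphabetized stdlib block, then third-party and recipe-local modules. The resulting file header should read roughly as follows; this is a sketch reconstructed from the hunks, not a verbatim copy, and entries not visible in the diff are marked as such. Note that subprocess is imported under the alias sp, which collided with the old sp tokenizer variable in main(); the sp -> tokenizer / sp_token renames later in this diff appear to resolve exactly that collision.

    # Reconstructed import block after the reformat.
    import argparse
    import logging
    import math
    import os
    import pdb
    import subprocess as sp  # `sp` now unambiguously means subprocess
    from pathlib import Path
    from typing import Dict, List, Optional, Tuple

    import k2
    import numpy as np
    from streaming_beam_search import (  # other entries not shown in the diff
        greedy_search,
        modified_beam_search,
    )
    from tokenizer import Tokenizer
    from torch import Tensor, nn
    from torch.nn.utils.rnn import pad_sequence
    from train import add_model_arguments, get_model, get_params

    LOG_EPS = math.log(1e-10)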
@@ -124,7 +123,7 @@ def get_parser():
         default="data/lang_bpe_500/bpe.model",
         help="Path to the BPE model",
     )
 
     parser.add_argument(
         "--lang-dir",
         type=Path,
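The hunk above only touches formatting around two parser options. For readers unfamiliar with the type=Path idiom: argparse converts the raw string to a pathlib.Path at parse time. A minimal self-contained sketch; the --bpe-model flag name and the parse_args input are assumptions for illustration, since only the default and help strings appear in the hunk:

    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--bpe-model",  # hypothetical flag name; the hunk shows only default/help
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )
    parser.add_argument(
        "--lang-dir",
        type=Path,  # argparse calls Path(value), so the attribute is a pathlib.Path
    )

    args = parser.parse_args(["--lang-dir", "data/lang_bpe_500"])
    print(args.lang_dir / "bpe.model")  # Path arithmetic works directly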
@@ -449,14 +448,14 @@ def decode_one_chunk(
     feature_lens = []
     states = []
     processed_lens = []  # Used in fast-beam-search
 
     for stream in decode_streams:
         feat, feat_len = stream.get_feature_frames(chunk_size * 2)
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
         processed_lens.append(stream.done_frames)
 
     feature_lens = torch.tensor(feature_lens, device=model.device)
     features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
 
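The context above shows how per-stream feature chunks are batched: streams can have different numbers of frames, so pad_sequence right-pads the short ones with LOG_EPS = math.log(1e-10), i.e. padded frames carry a log-value so small they behave like silence. A minimal sketch of that padding behavior, using toy tensors rather than the recipe's real features:

    import math

    import torch
    from torch.nn.utils.rnn import pad_sequence

    LOG_EPS = math.log(1e-10)

    # Two toy "streams": 4 frames and 2 frames, 3 feature bins each.
    a = torch.zeros(4, 3)
    b = torch.zeros(2, 3)

    features = pad_sequence([a, b], batch_first=True, padding_value=LOG_EPS)
    feature_lens = torch.tensor([4, 2])

    print(features.shape)      # torch.Size([2, 4, 3]) -- padded to the longest stream
    print(features[1, 2:, 0])  # tensor([-23.0259, -23.0259]) -- the LOG_EPS padding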
@@ -518,9 +517,9 @@ def decode_one_chunk(
         decode_streams[i].states = states[i]
         decode_streams[i].done_frames += encoder_out_lens[i]
         # if decode_streams[i].done:
         #     finished_streams.append(i)
         finished_streams.append(i)
 
     return finished_streams
 
 
@@ -528,7 +527,7 @@ def decode_dataset(
     cuts: CutSet,
     params: AttributeDict,
     model: nn.Module,
-    sp: Tokenizer,
+    tokenizer: Tokenizer,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -540,7 +539,7 @@ def decode_dataset(
         It is returned by :func:`get_params`.
       model:
         The neural model.
-      sp:
+      tokenizer:
         The BPE model.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
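The substantive change in the two hunks above is renaming decode_dataset's sp parameter to tokenizer, with the docstring updated to match. Because the function is called with keyword arguments, every call site passing sp=... must change too (the call-site hunk near the end of this diff does). A tiny self-contained illustration of why, using stand-ins rather than the icefall classes:

    # Stand-ins for illustration only; not the icefall classes.
    class DummyTokenizer:
        def decode(self, ids):
            return " ".join(str(i) for i in ids)

    def decode_dataset(tokenizer: DummyTokenizer) -> str:  # was: def decode_dataset(sp)
        return tokenizer.decode([1, 2, 3])

    print(decode_dataset(tokenizer=DummyTokenizer()))  # OK after the rename
    # decode_dataset(sp=DummyTokenizer())  # TypeError: unexpected keyword argument 'sp'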
@@ -608,7 +607,7 @@ def decode_dataset(
                 (
                     decode_streams[i].id,
                     decode_streams[i].ground_truth.split(),
-                    sp.decode(decode_streams[i].decoding_result()).split(),
+                    tokenizer.decode(decode_streams[i].decoding_result()).split(),
                 )
             )
             del decode_streams[i]
@@ -628,21 +627,20 @@ def decode_dataset(
         )
         # print('INSIDE FOR LOOP ')
         # print(finished_streams)
 
         if not finished_streams:
             print("No finished streams, breaking the loop")
             break
 
-
         for i in sorted(finished_streams, reverse=True):
             try:
                 decode_results.append(
                     (
                         decode_streams[i].id,
                         decode_streams[i].ground_truth.split(),
-                        sp.decode(decode_streams[i].decoding_result()).split(),
+                        tokenizer.decode(decode_streams[i].decoding_result()).split(),
                     )
                 )
                 del decode_streams[i]
             except IndexError as e:
                 print(f"IndexError: {e}")
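One detail in the loop above that is easy to miss: finished streams are deleted in descending index order, because del list[i] shifts every later element down by one, so ascending deletion would remove the wrong elements; the try/except IndexError is a guard for exactly that kind of stale index. A small self-contained demonstration:

    streams = ["s0", "s1", "s2", "s3"]
    finished = [1, 3]

    # Descending order keeps the not-yet-deleted indices valid:
    # deleting index 3 first leaves index 1 untouched.
    for i in sorted(finished, reverse=True):
        del streams[i]

    print(streams)  # ['s0', 's2']

    # Ascending order would go wrong: after `del streams[1]`, the old
    # index 3 points past the end of the list and raises IndexError.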
@@ -650,7 +648,7 @@ def decode_dataset(
                 print(f"finished_streams: {finished_streams}")
                 print(f"i: {i}")
                 continue
 
     if params.decoding_method == "greedy_search":
         key = "greedy_search"
     elif params.decoding_method == "fast_beam_search":
@@ -663,7 +661,7 @@ def decode_dataset(
         key = f"num_active_paths_{params.num_active_paths}"
     else:
         raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
     torch.cuda.synchronize()
     return {key: decode_results}
 
 
@@ -755,12 +753,12 @@ def main():
 
     logging.info(f"Device: {device}")
 
-    sp = Tokenizer.load(params.lang, params.lang_type)
+    sp_token = Tokenizer.load(params.lang, params.lang_type)
 
     # <blk> and <unk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = sp_token.piece_to_id("<blk>")
+    params.unk_id = sp_token.piece_to_id("<unk>")
+    params.vocab_size = sp_token.get_piece_size()
 
     logging.info(params)
 
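Here the local variable sp becomes sp_token, freeing the name sp for the subprocess alias introduced in the import hunks. The recipe's Tokenizer exposes a SentencePiece-style interface; a minimal sketch of the same three calls using the sentencepiece package directly, assumed equivalent for these methods, with an illustrative model path:

    import sentencepiece as spm

    # Load a trained BPE model (path as used in this recipe's defaults).
    sp_token = spm.SentencePieceProcessor(model_file="data/lang_bpe_500/bpe.model")

    blank_id = sp_token.piece_to_id("<blk>")  # ID of the blank symbol
    unk_id = sp_token.piece_to_id("<unk>")    # ID of the unknown symbol
    vocab_size = sp_token.get_piece_size()    # total number of BPE pieces

    print(blank_id, unk_id, vocab_size)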
@@ -854,11 +852,11 @@ def main():
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
     # we need cut ids to display recognition results.
     args.return_cuts = True
     reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
 
     valid_cuts = reazonspeech_corpus.valid_cuts()
     test_cuts = reazonspeech_corpus.test_cuts()
 
@@ -870,7 +868,7 @@ def main():
             cuts=test_cut,
             params=params,
             model=model,
-            sp=sp,
+            tokenizer=sp_token,
             decoding_graph=decoding_graph,
         )
         save_results(
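This is the call-site half of the rename: both the keyword (sp= to tokenizer=) and the renamed local (sp to sp_token) change together, keeping the call consistent with the new decode_dataset signature. Assembled from the context lines, the updated call reads roughly as below; the results_dict assignment is assumed, as it sits just above the hunk:

    results_dict = decode_dataset(  # assignment assumed, not shown in the hunk
        cuts=test_cut,
        params=params,
        model=model,
        tokenizer=sp_token,  # was: sp=sp
        decoding_graph=decoding_graph,
    )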
@@ -878,9 +876,9 @@ def main():
             test_set_name=test_set,
             results_dict=results_dict,
         )
 
     # valid_cuts = reazonspeech_corpus.valid_cuts()
 
     # for valid_cut in valid_cuts:
     #     results_dict = decode_dataset(
     #         cuts=valid_cut,
@@ -894,7 +892,7 @@ def main():
     #         test_set_name="valid",
     #         results_dict=results_dict,
     #     )
 
     logging.info("Done!")
 
 