Reformatted streaming_decode.py with flake8

This commit is contained in:
Bailey Hirota 2025-01-15 01:11:29 +09:00
parent b574e68bf4
commit 9ab3021640

View File

@ -22,13 +22,14 @@ Usage:
"""
import pdb
import argparse
import logging
import math
import os
import pdb
import subprocess as sp
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from tokenizer import Tokenizer
import k2
import numpy as np
@ -42,6 +43,7 @@ from streaming_beam_search import (
greedy_search,
modified_beam_search,
)
from tokenizer import Tokenizer
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
@ -61,9 +63,6 @@ from icefall.utils import (
write_error_stats,
)
import subprocess as sp
import os
LOG_EPS = math.log(1e-10)
@ -124,7 +123,7 @@ def get_parser():
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
@ -449,14 +448,14 @@ def decode_one_chunk(
feature_lens = []
states = []
processed_lens = [] # Used in fast-beam-search
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(chunk_size * 2)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
feature_lens = torch.tensor(feature_lens, device=model.device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
@ -518,9 +517,9 @@ def decode_one_chunk(
decode_streams[i].states = states[i]
decode_streams[i].done_frames += encoder_out_lens[i]
# if decode_streams[i].done:
# finished_streams.append(i)
# finished_streams.append(i)
finished_streams.append(i)
return finished_streams
@ -528,7 +527,7 @@ def decode_dataset(
cuts: CutSet,
params: AttributeDict,
model: nn.Module,
sp: Tokenizer,
tokenizer: Tokenizer,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -540,7 +539,7 @@ def decode_dataset(
It is returned by :func:`get_params`.
model:
The neural model.
sp:
tokenizer:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
@ -608,7 +607,7 @@ def decode_dataset(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
tokenizer.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
@ -628,21 +627,20 @@ def decode_dataset(
)
# print('INSIDE FOR LOOP ')
# print(finished_streams)
if not finished_streams:
print("No finished streams, breaking the loop")
break
for i in sorted(finished_streams, reverse=True):
try:
try:
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
tokenizer.decode(decode_streams[i].decoding_result()).split(),
)
)
)
del decode_streams[i]
except IndexError as e:
print(f"IndexError: {e}")
@ -650,7 +648,7 @@ def decode_dataset(
print(f"finished_streams: {finished_streams}")
print(f"i: {i}")
continue
if params.decoding_method == "greedy_search":
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
@ -663,7 +661,7 @@ def decode_dataset(
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
torch.cuda.synchronize()
torch.cuda.synchronize()
return {key: decode_results}
@ -755,12 +753,12 @@ def main():
logging.info(f"Device: {device}")
sp = Tokenizer.load(params.lang, params.lang_type)
sp_token = Tokenizer.load(params.lang, params.lang_type)
# <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = sp_token.piece_to_id("<blk>")
params.unk_id = sp_token.piece_to_id("<unk>")
params.vocab_size = sp_token.get_piece_size()
logging.info(params)
@ -854,11 +852,11 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
valid_cuts = reazonspeech_corpus.valid_cuts()
test_cuts = reazonspeech_corpus.test_cuts()
@ -870,7 +868,7 @@ def main():
cuts=test_cut,
params=params,
model=model,
sp=sp,
tokenizer=sp_token,
decoding_graph=decoding_graph,
)
save_results(
@ -878,9 +876,9 @@ def main():
test_set_name=test_set,
results_dict=results_dict,
)
# valid_cuts = reazonspeech_corpus.valid_cuts()
# for valid_cut in valid_cuts:
# results_dict = decode_dataset(
# cuts=valid_cut,
@ -894,7 +892,7 @@ def main():
# test_set_name="valid",
# results_dict=results_dict,
# )
logging.info("Done!")