Reformatted streaming_decode.py with flake8

Bailey Hirota 2025-01-15 01:11:29 +09:00
parent b574e68bf4
commit 9ab3021640

streaming_decode.py

@@ -22,13 +22,14 @@ Usage:
 """
-import pdb
 import argparse
 import logging
 import math
+import os
+import pdb
+import subprocess as sp
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
-from tokenizer import Tokenizer

 import k2
 import numpy as np
@@ -42,6 +43,7 @@ from streaming_beam_search import (
     greedy_search,
     modified_beam_search,
 )
+from tokenizer import Tokenizer
 from torch import Tensor, nn
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_model, get_params
@@ -61,9 +63,6 @@ from icefall.utils import (
     write_error_stats,
 )
-import subprocess as sp
-import os

 LOG_EPS = math.log(1e-10)
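
Note on the import moves above: this commit aliases subprocess as `sp` at module scope, which would collide with the tokenizer that the old code also called `sp`; the hunks below therefore rename the tokenizer parameter to `tokenizer` and the local variable to `sp_token`. A minimal sketch of the shadowing hazard, with illustrative names not taken from this file:

    import subprocess as sp

    def decode(sp) -> None:
        # The parameter shadows the module-level alias, so inside this
        # function `sp` is the tokenizer object and the call fails:
        sp.run(["echo", "ok"])  # AttributeError on a tokenizer

    def decode_renamed(tokenizer) -> None:
        # With the parameter renamed, `sp` is unambiguously subprocess:
        sp.run(["echo", "ok"])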
@@ -124,7 +123,7 @@ def get_parser():
         default="data/lang_bpe_500/bpe.model",
         help="Path to the BPE model",
     )

     parser.add_argument(
         "--lang-dir",
         type=Path,
@@ -449,14 +448,14 @@ def decode_one_chunk(
     feature_lens = []
     states = []
     processed_lens = []  # Used in fast-beam-search

     for stream in decode_streams:
         feat, feat_len = stream.get_feature_frames(chunk_size * 2)
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
         processed_lens.append(stream.done_frames)

     feature_lens = torch.tensor(feature_lens, device=model.device)
     features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
@@ -518,9 +517,9 @@ def decode_one_chunk(
         decode_streams[i].states = states[i]
         decode_streams[i].done_frames += encoder_out_lens[i]
         # if decode_streams[i].done:
         #     finished_streams.append(i)
         finished_streams.append(i)

     return finished_streams
@@ -528,7 +527,7 @@ def decode_dataset(
     cuts: CutSet,
     params: AttributeDict,
     model: nn.Module,
-    sp: Tokenizer,
+    tokenizer: Tokenizer,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -540,7 +539,7 @@ def decode_dataset(
         It is returned by :func:`get_params`.
       model:
         The neural model.
-      sp:
+      tokenizer:
         The BPE model.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
@@ -608,7 +607,7 @@ def decode_dataset(
                 (
                     decode_streams[i].id,
                     decode_streams[i].ground_truth.split(),
-                    sp.decode(decode_streams[i].decoding_result()).split(),
+                    tokenizer.decode(decode_streams[i].decoding_result()).split(),
                 )
             )
             del decode_streams[i]
@@ -628,21 +627,20 @@ def decode_dataset(
         )
         # print('INSIDE FOR LOOP ')
         # print(finished_streams)

         if not finished_streams:
             print("No finished streams, breaking the loop")
             break

         for i in sorted(finished_streams, reverse=True):
             try:
                 decode_results.append(
                     (
                         decode_streams[i].id,
                         decode_streams[i].ground_truth.split(),
-                        sp.decode(decode_streams[i].decoding_result()).split(),
+                        tokenizer.decode(decode_streams[i].decoding_result()).split(),
                     )
                 )
                 del decode_streams[i]
             except IndexError as e:
                 print(f"IndexError: {e}")
@@ -650,7 +648,7 @@ def decode_dataset(
                 print(f"finished_streams: {finished_streams}")
                 print(f"i: {i}")
                 continue

     if params.decoding_method == "greedy_search":
         key = "greedy_search"
     elif params.decoding_method == "fast_beam_search":
@@ -663,7 +661,7 @@ def decode_dataset(
         key = f"num_active_paths_{params.num_active_paths}"
     else:
         raise ValueError(f"Unsupported decoding method: {params.decoding_method}")

     torch.cuda.synchronize()
     return {key: decode_results}
@@ -755,12 +753,12 @@ def main():
     logging.info(f"Device: {device}")

-    sp = Tokenizer.load(params.lang, params.lang_type)
+    sp_token = Tokenizer.load(params.lang, params.lang_type)

     # <blk> and <unk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = sp_token.piece_to_id("<blk>")
+    params.unk_id = sp_token.piece_to_id("<unk>")
+    params.vocab_size = sp_token.get_piece_size()

     logging.info(params)
@@ -854,11 +852,11 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

     # we need cut ids to display recognition results.
     args.return_cuts = True
     reazonspeech_corpus = ReazonSpeechAsrDataModule(args)

     valid_cuts = reazonspeech_corpus.valid_cuts()
     test_cuts = reazonspeech_corpus.test_cuts()
@@ -870,7 +868,7 @@ def main():
             cuts=test_cut,
             params=params,
             model=model,
-            sp=sp,
+            tokenizer=sp_token,
             decoding_graph=decoding_graph,
         )
         save_results(
@@ -878,9 +876,9 @@ def main():
             test_set_name=test_set,
             results_dict=results_dict,
         )

     # valid_cuts = reazonspeech_corpus.valid_cuts()
     # for valid_cut in valid_cuts:
     #     results_dict = decode_dataset(
     #         cuts=valid_cut,
@@ -894,7 +892,7 @@ def main():
     #         test_set_name="valid",
     #         results_dict=results_dict,
     #     )

     logging.info("Done!")