fix formatting issues

This commit is contained in:
root 2024-08-14 15:29:29 +09:00
parent 563292599b
commit 814d3ac702

View File

@ -22,13 +22,15 @@ Usage:
""" """
import pdb
import argparse import argparse
import logging import logging
import math import math
import os
import pdb
# import subprocess as sp
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from tokenizer import Tokenizer
import k2 import k2
import numpy as np import numpy as np
@ -42,6 +44,7 @@ from streaming_beam_search import (
greedy_search, greedy_search,
modified_beam_search, modified_beam_search,
) )
from tokenizer import Tokenizer
from torch import Tensor, nn from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params from train import add_model_arguments, get_model, get_params
@ -61,9 +64,6 @@ from icefall.utils import (
write_error_stats, write_error_stats,
) )
import subprocess as sp
import os
LOG_EPS = math.log(1e-10) LOG_EPS = math.log(1e-10)
@ -124,7 +124,7 @@ def get_parser():
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model", help="Path to the BPE model",
) )
parser.add_argument( parser.add_argument(
"--lang-dir", "--lang-dir",
type=Path, type=Path,
@ -449,14 +449,14 @@ def decode_one_chunk(
feature_lens = [] feature_lens = []
states = [] states = []
processed_lens = [] # Used in fast-beam-search processed_lens = [] # Used in fast-beam-search
for stream in decode_streams: for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(chunk_size * 2) feat, feat_len = stream.get_feature_frames(chunk_size * 2)
features.append(feat) features.append(feat)
feature_lens.append(feat_len) feature_lens.append(feat_len)
states.append(stream.states) states.append(stream.states)
processed_lens.append(stream.done_frames) processed_lens.append(stream.done_frames)
feature_lens = torch.tensor(feature_lens, device=model.device) feature_lens = torch.tensor(feature_lens, device=model.device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
@ -518,9 +518,9 @@ def decode_one_chunk(
decode_streams[i].states = states[i] decode_streams[i].states = states[i]
decode_streams[i].done_frames += encoder_out_lens[i] decode_streams[i].done_frames += encoder_out_lens[i]
# if decode_streams[i].done: # if decode_streams[i].done:
# finished_streams.append(i) # finished_streams.append(i)
finished_streams.append(i) finished_streams.append(i)
return finished_streams return finished_streams
@ -628,21 +628,20 @@ def decode_dataset(
) )
# print('INSIDE FOR LOOP ') # print('INSIDE FOR LOOP ')
# print(finished_streams) # print(finished_streams)
if not finished_streams: if not finished_streams:
print("No finished streams, breaking the loop") print("No finished streams, breaking the loop")
break break
for i in sorted(finished_streams, reverse=True): for i in sorted(finished_streams, reverse=True):
try: try:
decode_results.append( decode_results.append(
( (
decode_streams[i].id, decode_streams[i].id,
decode_streams[i].ground_truth.split(), decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(), sp.decode(decode_streams[i].decoding_result()).split(),
) )
) )
del decode_streams[i] del decode_streams[i]
except IndexError as e: except IndexError as e:
print(f"IndexError: {e}") print(f"IndexError: {e}")
@ -650,7 +649,7 @@ def decode_dataset(
print(f"finished_streams: {finished_streams}") print(f"finished_streams: {finished_streams}")
print(f"i: {i}") print(f"i: {i}")
continue continue
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
key = "greedy_search" key = "greedy_search"
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
@ -663,7 +662,7 @@ def decode_dataset(
key = f"num_active_paths_{params.num_active_paths}" key = f"num_active_paths_{params.num_active_paths}"
else: else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}") raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
torch.cuda.synchronize() torch.cuda.synchronize()
return {key: decode_results} return {key: decode_results}
@ -854,11 +853,11 @@ def main():
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results. # we need cut ids to display recognition results.
args.return_cuts = True args.return_cuts = True
reazonspeech_corpus = ReazonSpeechAsrDataModule(args) reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
valid_cuts = reazonspeech_corpus.valid_cuts() valid_cuts = reazonspeech_corpus.valid_cuts()
test_cuts = reazonspeech_corpus.test_cuts() test_cuts = reazonspeech_corpus.test_cuts()
@ -878,9 +877,9 @@ def main():
test_set_name=test_set, test_set_name=test_set,
results_dict=results_dict, results_dict=results_dict,
) )
# valid_cuts = reazonspeech_corpus.valid_cuts() # valid_cuts = reazonspeech_corpus.valid_cuts()
# for valid_cut in valid_cuts: # for valid_cut in valid_cuts:
# results_dict = decode_dataset( # results_dict = decode_dataset(
# cuts=valid_cut, # cuts=valid_cut,
@ -894,7 +893,7 @@ def main():
# test_set_name="valid", # test_set_name="valid",
# results_dict=results_dict, # results_dict=results_dict,
# ) # )
logging.info("Done!") logging.info("Done!")