fix formatting issues

This commit is contained in:
root 2024-08-14 15:29:29 +09:00
parent 563292599b
commit 814d3ac702

View File

@ -22,13 +22,15 @@ Usage:
"""
import pdb
import argparse
import logging
import math
import os
import pdb
# import subprocess as sp
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from tokenizer import Tokenizer
import k2
import numpy as np
@ -42,6 +44,7 @@ from streaming_beam_search import (
greedy_search,
modified_beam_search,
)
from tokenizer import Tokenizer
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
@ -61,9 +64,6 @@ from icefall.utils import (
write_error_stats,
)
import subprocess as sp
import os
LOG_EPS = math.log(1e-10)
@ -124,7 +124,7 @@ def get_parser():
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
@ -449,14 +449,14 @@ def decode_one_chunk(
feature_lens = []
states = []
processed_lens = [] # Used in fast-beam-search
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(chunk_size * 2)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
feature_lens = torch.tensor(feature_lens, device=model.device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
@ -518,9 +518,9 @@ def decode_one_chunk(
decode_streams[i].states = states[i]
decode_streams[i].done_frames += encoder_out_lens[i]
# if decode_streams[i].done:
# finished_streams.append(i)
# finished_streams.append(i)
finished_streams.append(i)
return finished_streams
@ -628,21 +628,20 @@ def decode_dataset(
)
# print('INSIDE FOR LOOP ')
# print(finished_streams)
if not finished_streams:
print("No finished streams, breaking the loop")
break
for i in sorted(finished_streams, reverse=True):
try:
try:
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
)
del decode_streams[i]
except IndexError as e:
print(f"IndexError: {e}")
@ -650,7 +649,7 @@ def decode_dataset(
print(f"finished_streams: {finished_streams}")
print(f"i: {i}")
continue
if params.decoding_method == "greedy_search":
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
@ -663,7 +662,7 @@ def decode_dataset(
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
torch.cuda.synchronize()
torch.cuda.synchronize()
return {key: decode_results}
@ -854,11 +853,11 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
valid_cuts = reazonspeech_corpus.valid_cuts()
test_cuts = reazonspeech_corpus.test_cuts()
@ -878,9 +877,9 @@ def main():
test_set_name=test_set,
results_dict=results_dict,
)
# valid_cuts = reazonspeech_corpus.valid_cuts()
# for valid_cut in valid_cuts:
# results_dict = decode_dataset(
# cuts=valid_cut,
@ -894,7 +893,7 @@ def main():
# test_set_name="valid",
# results_dict=results_dict,
# )
logging.info("Done!")