mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00

fix formatting issues

This commit is contained in:
parent 563292599b
commit 814d3ac702
@@ -22,13 +22,15 @@ Usage:
"""

import pdb
import argparse
import logging
import math
import os
import pdb

# import subprocess as sp
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from tokenizer import Tokenizer

import k2
import numpy as np
@@ -42,6 +44,7 @@ from streaming_beam_search import (
    greedy_search,
    modified_beam_search,
)
from tokenizer import Tokenizer
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
@@ -61,9 +64,6 @@ from icefall.utils import (
    write_error_stats,
)

import subprocess as sp
import os

LOG_EPS = math.log(1e-10)
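The import moves in the three hunks above follow the usual isort convention: one alphabetized block per group (standard library, third-party, local modules), separated by blank lines, with the stray mid-file `import subprocess as sp` / `import os` removed. A sketch of the resulting order, using names from this file (the grouping is the point; the local modules only resolve inside the recipe directory):

# Standard library, alphabetized
import argparse
import logging
import math
import os
import pdb
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# Third-party
import k2
import numpy as np
import torch
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence

# Local recipe modules
from tokenizer import Tokenizer
from train import add_model_arguments, get_model, get_params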
@@ -124,7 +124,7 @@ def get_parser():
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--lang-dir",
        type=Path,
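For reference, `type=Path` tells argparse to call `pathlib.Path` on the raw command-line string before storing it on the namespace. A minimal standalone sketch (the `--lang-dir` default below is hypothetical, not taken from this file):

import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument(
    "--bpe-model",
    type=str,
    default="data/lang_bpe_500/bpe.model",
    help="Path to the BPE model",
)
parser.add_argument("--lang-dir", type=Path, default=Path("data/lang_char"))

args = parser.parse_args(["--lang-dir", "data/lang_bpe_500"])
assert isinstance(args.lang_dir, Path)  # argparse applied Path() to the string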
@@ -449,14 +449,14 @@ def decode_one_chunk(
    feature_lens = []
    states = []
    processed_lens = []  # Used in fast-beam-search

    for stream in decode_streams:
        feat, feat_len = stream.get_feature_frames(chunk_size * 2)
        features.append(feat)
        feature_lens.append(feat_len)
        states.append(stream.states)
        processed_lens.append(stream.done_frames)

    feature_lens = torch.tensor(feature_lens, device=model.device)
    features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
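The last two lines above batch per-stream feature chunks of different lengths into one tensor. A minimal standalone sketch of the same pad_sequence pattern, with hypothetical shapes:

import math

import torch
from torch.nn.utils.rnn import pad_sequence

LOG_EPS = math.log(1e-10)  # pad value: log of a tiny probability

# Hypothetical per-stream chunks: (num_frames, feature_dim), lengths differ.
features = [torch.randn(18, 80), torch.randn(32, 80), torch.randn(25, 80)]
feature_lens = torch.tensor([f.size(0) for f in features])

# Right-pad every tensor to the longest one in the batch;
# batch_first=True yields shape (batch, max_frames, feature_dim).
batch = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
assert batch.shape == (3, 32, 80)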
@@ -518,9 +518,9 @@ def decode_one_chunk(
        decode_streams[i].states = states[i]
        decode_streams[i].done_frames += encoder_out_lens[i]
        # if decode_streams[i].done:
        #     finished_streams.append(i)
        # finished_streams.append(i)
        finished_streams.append(i)

    return finished_streams
@@ -628,21 +628,20 @@ def decode_dataset(
        )
        # print('INSIDE FOR LOOP ')
        # print(finished_streams)

        if not finished_streams:
            print("No finished streams, breaking the loop")
            break

        for i in sorted(finished_streams, reverse=True):
            try:
            try:
                decode_results.append(
                    (
                        decode_streams[i].id,
                        decode_streams[i].ground_truth.split(),
                        sp.decode(decode_streams[i].decoding_result()).split(),
                    )
                )
                )
                del decode_streams[i]
            except IndexError as e:
                print(f"IndexError: {e}")
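The cleanup loop above removes finished streams by index, and iterating the indices in descending order is what keeps the remaining indices valid while elements are deleted. A minimal illustration with hypothetical data:

streams = ["s0", "s1", "s2", "s3", "s4"]
finished = [1, 3]

# Deleting the highest index first leaves the lower indices untouched,
# so each del still hits the element it was meant for.
for i in sorted(finished, reverse=True):
    del streams[i]

assert streams == ["s0", "s2", "s4"]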
@@ -650,7 +649,7 @@ def decode_dataset(
                print(f"finished_streams: {finished_streams}")
                print(f"i: {i}")
                continue

    if params.decoding_method == "greedy_search":
        key = "greedy_search"
    elif params.decoding_method == "fast_beam_search":
@@ -663,7 +662,7 @@ def decode_dataset(
        key = f"num_active_paths_{params.num_active_paths}"
    else:
        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
    torch.cuda.synchronize()
    torch.cuda.synchronize()
    return {key: decode_results}
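torch.cuda.synchronize() blocks until every queued CUDA kernel has finished; since kernel launches return immediately, calling it before reading results or a timer avoids observing the GPU mid-flight. A small sketch of the usual pattern, assuming a CUDA device is available:

import time

import torch

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")
    torch.cuda.synchronize()  # drain pending work before starting the clock
    start = time.perf_counter()
    y = x @ x  # launch returns immediately; the kernel runs asynchronously
    torch.cuda.synchronize()  # wait until the matmul has actually finished
    print(f"matmul took {time.perf_counter() - start:.4f}s")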
@@ -854,11 +853,11 @@ def main():
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    reazonspeech_corpus = ReazonSpeechAsrDataModule(args)

    valid_cuts = reazonspeech_corpus.valid_cuts()
    test_cuts = reazonspeech_corpus.test_cuts()
@@ -878,9 +877,9 @@ def main():
            test_set_name=test_set,
            results_dict=results_dict,
        )

    # valid_cuts = reazonspeech_corpus.valid_cuts()

    # for valid_cut in valid_cuts:
    #     results_dict = decode_dataset(
    #         cuts=valid_cut,
@@ -894,7 +893,7 @@ def main():
    #         test_set_name="valid",
    #         results_dict=results_dict,
    #     )

    logging.info("Done!")