Reformatted streaming_decode.py with flake8

This commit is contained in:
Bailey Hirota 2025-01-15 01:11:29 +09:00
parent b574e68bf4
commit 9ab3021640

View File

@ -22,13 +22,14 @@ Usage:
""" """
import pdb
import argparse import argparse
import logging import logging
import math import math
import os
import pdb
import subprocess as sp
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from tokenizer import Tokenizer
import k2 import k2
import numpy as np import numpy as np
@ -42,6 +43,7 @@ from streaming_beam_search import (
greedy_search, greedy_search,
modified_beam_search, modified_beam_search,
) )
from tokenizer import Tokenizer
from torch import Tensor, nn from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params from train import add_model_arguments, get_model, get_params
@ -61,9 +63,6 @@ from icefall.utils import (
write_error_stats, write_error_stats,
) )
import subprocess as sp
import os
LOG_EPS = math.log(1e-10) LOG_EPS = math.log(1e-10)
@ -518,7 +517,7 @@ def decode_one_chunk(
decode_streams[i].states = states[i] decode_streams[i].states = states[i]
decode_streams[i].done_frames += encoder_out_lens[i] decode_streams[i].done_frames += encoder_out_lens[i]
# if decode_streams[i].done: # if decode_streams[i].done:
# finished_streams.append(i) # finished_streams.append(i)
finished_streams.append(i) finished_streams.append(i)
return finished_streams return finished_streams
@ -528,7 +527,7 @@ def decode_dataset(
cuts: CutSet, cuts: CutSet,
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
sp: Tokenizer, tokenizer: Tokenizer,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]: ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset. """Decode dataset.
@ -540,7 +539,7 @@ def decode_dataset(
It is returned by :func:`get_params`. It is returned by :func:`get_params`.
model: model:
The neural model. The neural model.
sp: tokenizer:
The BPE model. The BPE model.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
@ -608,7 +607,7 @@ def decode_dataset(
( (
decode_streams[i].id, decode_streams[i].id,
decode_streams[i].ground_truth.split(), decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(), tokenizer.decode(decode_streams[i].decoding_result()).split(),
) )
) )
del decode_streams[i] del decode_streams[i]
@ -633,14 +632,13 @@ def decode_dataset(
print("No finished streams, breaking the loop") print("No finished streams, breaking the loop")
break break
for i in sorted(finished_streams, reverse=True): for i in sorted(finished_streams, reverse=True):
try: try:
decode_results.append( decode_results.append(
( (
decode_streams[i].id, decode_streams[i].id,
decode_streams[i].ground_truth.split(), decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(), tokenizer.decode(decode_streams[i].decoding_result()).split(),
) )
) )
del decode_streams[i] del decode_streams[i]
@ -755,12 +753,12 @@ def main():
logging.info(f"Device: {device}") logging.info(f"Device: {device}")
sp = Tokenizer.load(params.lang, params.lang_type) sp_token = Tokenizer.load(params.lang, params.lang_type)
# <blk> and <unk> is defined in local/train_bpe_model.py # <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = sp_token.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = sp_token.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp_token.get_piece_size()
logging.info(params) logging.info(params)
@ -870,7 +868,7 @@ def main():
cuts=test_cut, cuts=test_cut,
params=params, params=params,
model=model, model=model,
sp=sp, tokenizer=sp_token,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
) )
save_results( save_results(