mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
Reformatted streaming_decode.py with flake8
This commit is contained in:
parent
b574e68bf4
commit
9ab3021640
@ -22,13 +22,14 @@ Usage:
|
||||
|
||||
"""
|
||||
|
||||
import pdb
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import pdb
|
||||
import subprocess as sp
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from tokenizer import Tokenizer
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
@ -42,6 +43,7 @@ from streaming_beam_search import (
|
||||
greedy_search,
|
||||
modified_beam_search,
|
||||
)
|
||||
from tokenizer import Tokenizer
|
||||
from torch import Tensor, nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
@ -61,9 +63,6 @@ from icefall.utils import (
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
import subprocess as sp
|
||||
import os
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
@ -528,7 +527,7 @@ def decode_dataset(
|
||||
cuts: CutSet,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: Tokenizer,
|
||||
tokenizer: Tokenizer,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
@ -540,7 +539,7 @@ def decode_dataset(
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
tokenizer:
|
||||
The BPE model.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
|
||||
@ -608,7 +607,7 @@ def decode_dataset(
|
||||
(
|
||||
decode_streams[i].id,
|
||||
decode_streams[i].ground_truth.split(),
|
||||
sp.decode(decode_streams[i].decoding_result()).split(),
|
||||
tokenizer.decode(decode_streams[i].decoding_result()).split(),
|
||||
)
|
||||
)
|
||||
del decode_streams[i]
|
||||
@ -633,14 +632,13 @@ def decode_dataset(
|
||||
print("No finished streams, breaking the loop")
|
||||
break
|
||||
|
||||
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
try:
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].id,
|
||||
decode_streams[i].ground_truth.split(),
|
||||
sp.decode(decode_streams[i].decoding_result()).split(),
|
||||
tokenizer.decode(decode_streams[i].decoding_result()).split(),
|
||||
)
|
||||
)
|
||||
del decode_streams[i]
|
||||
@ -755,12 +753,12 @@ def main():
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = Tokenizer.load(params.lang, params.lang_type)
|
||||
sp_token = Tokenizer.load(params.lang, params.lang_type)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
params.blank_id = sp_token.piece_to_id("<blk>")
|
||||
params.unk_id = sp_token.piece_to_id("<unk>")
|
||||
params.vocab_size = sp_token.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
@ -870,7 +868,7 @@ def main():
|
||||
cuts=test_cut,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
tokenizer=sp_token,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
save_results(
|
||||
|
Loading…
x
Reference in New Issue
Block a user