mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-11 01:54:20 +00:00
Fix style
This commit is contained in:
parent
aecaecfb17
commit
605838da55
@ -380,7 +380,7 @@ def decode_one_batch(
|
||||
x_lens=feature_lens,
|
||||
chunk_size=params.right_chunk_size,
|
||||
left_context=params.left_context,
|
||||
simulate_streaming=True
|
||||
simulate_streaming=True,
|
||||
)
|
||||
else:
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
|
@ -106,7 +106,7 @@ class DecodeStream(object):
|
||||
self.features.size(0) - self.num_processed_frames, chunk_size + 3
|
||||
)
|
||||
ret_features = self.features[
|
||||
self.num_processed_frames : self.num_processed_frames
|
||||
self.num_processed_frames : self.num_processed_frames # noqa
|
||||
+ ret_chunk_size,
|
||||
:,
|
||||
]
|
||||
|
@ -157,9 +157,9 @@ class Transducer(nn.Module):
|
||||
else:
|
||||
offset = (boundary[:, 3] - 1) / 2
|
||||
total_syms = torch.sum(boundary[:, 2])
|
||||
offset = torch.arange(
|
||||
T0, device=px_grad.device
|
||||
).reshape(1, 1, T0) - offset.reshape(B, 1, 1)
|
||||
offset = torch.arange(T0, device=px_grad.device).reshape(
|
||||
1, 1, T0
|
||||
) - offset.reshape(B, 1, 1)
|
||||
sym_delay = px_grad * offset
|
||||
sym_delay = torch.sum(sym_delay) / total_syms
|
||||
|
||||
|
@ -28,35 +28,35 @@ Usage:
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import k2
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from decode_stream import DecodeStream
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.decode import one_best_decoding
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
from icefall.decode import Nbest, one_best_decoding
|
||||
from icefall.utils import get_texts
|
||||
from kaldifeat import FbankOptions, Fbank
|
||||
from lhotse import CutSet
|
||||
from train import get_params, get_transducer_model
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
@ -239,7 +239,7 @@ def greedy_search(
|
||||
|
||||
for t in range(T):
|
||||
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
|
||||
current_encoder_out = encoder_out[:, t : t + 1, :]
|
||||
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out.unsqueeze(2),
|
||||
@ -494,7 +494,7 @@ def decode_dataset(
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
hyp = decode_streams[i].hyp
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = hyp[params.context_size :]
|
||||
hyp = hyp[params.context_size :] # noqa
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].ground_truth.split(),
|
||||
@ -512,7 +512,7 @@ def decode_dataset(
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
hyp = decode_streams[i].hyp
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = hyp[params.context_size :]
|
||||
hyp = hyp[params.context_size :] # noqa
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].ground_truth.split(),
|
||||
|
@ -104,6 +104,7 @@ from icefall.utils import (
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@ -339,7 +340,7 @@ def decode_one_batch(
|
||||
x_lens=feature_lens,
|
||||
chunk_size=params.right_chunk_size,
|
||||
left_context=params.left_context,
|
||||
simulate_streaming=True
|
||||
simulate_streaming=True,
|
||||
)
|
||||
else:
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
|
@ -28,35 +28,35 @@ Usage:
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import k2
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from decode_stream import DecodeStream
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.decode import one_best_decoding
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
from icefall.decode import Nbest, one_best_decoding
|
||||
from icefall.utils import get_texts
|
||||
from kaldifeat import FbankOptions, Fbank
|
||||
from lhotse import CutSet
|
||||
from train import get_params, get_transducer_model
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
@ -241,7 +241,7 @@ def greedy_search(
|
||||
|
||||
for t in range(T):
|
||||
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
|
||||
current_encoder_out = encoder_out[:, t : t + 1, :]
|
||||
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out.unsqueeze(2),
|
||||
@ -501,7 +501,7 @@ def decode_dataset(
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
hyp = decode_streams[i].hyp
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = hyp[params.context_size :]
|
||||
hyp = hyp[params.context_size :] # noqa
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].ground_truth.split(),
|
||||
@ -519,7 +519,7 @@ def decode_dataset(
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
hyp = decode_streams[i].hyp
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = hyp[params.context_size :]
|
||||
hyp = hyp[params.context_size :] # noqa
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].ground_truth.split(),
|
||||
|
@ -651,7 +651,7 @@ def compute_loss(
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||
|
||||
|
||||
if params.return_sym_delay:
|
||||
info["sym_delay"] = sym_delay.detach().cpu().item()
|
||||
|
||||
|
@ -71,6 +71,7 @@ Usage:
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
@ -105,6 +106,7 @@ from icefall.utils import (
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@ -350,7 +352,7 @@ def decode_one_batch(
|
||||
x_lens=feature_lens,
|
||||
chunk_size=params.right_chunk_size,
|
||||
left_context=params.left_context,
|
||||
simulate_streaming=True
|
||||
simulate_streaming=True,
|
||||
)
|
||||
else:
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
|
@ -56,7 +56,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
|
Loading…
x
Reference in New Issue
Block a user