Fix style

This commit is contained in:
pkufool 2022-05-29 07:48:52 +08:00
parent aecaecfb17
commit 605838da55
9 changed files with 33 additions and 31 deletions

View File

@ -380,7 +380,7 @@ def decode_one_batch(
x_lens=feature_lens,
chunk_size=params.right_chunk_size,
left_context=params.left_context,
simulate_streaming=True
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(

View File

@ -106,7 +106,7 @@ class DecodeStream(object):
self.features.size(0) - self.num_processed_frames, chunk_size + 3
)
ret_features = self.features[
self.num_processed_frames : self.num_processed_frames
self.num_processed_frames : self.num_processed_frames # noqa
+ ret_chunk_size,
:,
]

View File

@ -157,9 +157,9 @@ class Transducer(nn.Module):
else:
offset = (boundary[:, 3] - 1) / 2
total_syms = torch.sum(boundary[:, 2])
offset = torch.arange(
T0, device=px_grad.device
).reshape(1, 1, T0) - offset.reshape(B, 1, 1)
offset = torch.arange(T0, device=px_grad.device).reshape(
1, 1, T0
) - offset.reshape(B, 1, 1)
sym_delay = px_grad * offset
sym_delay = torch.sum(sym_delay) / total_syms

View File

@ -28,35 +28,35 @@ Usage:
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import one_best_decoding
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts
from kaldifeat import FbankOptions, Fbank
from lhotse import CutSet
from train import get_params, get_transducer_model
from torch.nn.utils.rnn import pad_sequence
LOG_EPS = math.log(1e-10)
@ -239,7 +239,7 @@ def greedy_search(
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :]
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(
current_encoder_out.unsqueeze(2),
@ -494,7 +494,7 @@ def decode_dataset(
for i in sorted(finished_streams, reverse=True):
hyp = decode_streams[i].hyp
if params.decoding_method == "greedy_search":
hyp = hyp[params.context_size :]
hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
@ -512,7 +512,7 @@ def decode_dataset(
for i in sorted(finished_streams, reverse=True):
hyp = decode_streams[i].hyp
if params.decoding_method == "greedy_search":
hyp = hyp[params.context_size :]
hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),

View File

@ -104,6 +104,7 @@ from icefall.utils import (
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -339,7 +340,7 @@ def decode_one_batch(
x_lens=feature_lens,
chunk_size=params.right_chunk_size,
left_context=params.left_context,
simulate_streaming=True
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(

View File

@ -28,35 +28,35 @@ Usage:
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import one_best_decoding
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts
from kaldifeat import FbankOptions, Fbank
from lhotse import CutSet
from train import get_params, get_transducer_model
from torch.nn.utils.rnn import pad_sequence
LOG_EPS = math.log(1e-10)
@ -241,7 +241,7 @@ def greedy_search(
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :]
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(
current_encoder_out.unsqueeze(2),
@ -501,7 +501,7 @@ def decode_dataset(
for i in sorted(finished_streams, reverse=True):
hyp = decode_streams[i].hyp
if params.decoding_method == "greedy_search":
hyp = hyp[params.context_size :]
hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),
@ -519,7 +519,7 @@ def decode_dataset(
for i in sorted(finished_streams, reverse=True):
hyp = decode_streams[i].hyp
if params.decoding_method == "greedy_search":
hyp = hyp[params.context_size :]
hyp = hyp[params.context_size :] # noqa
decode_results.append(
(
decode_streams[i].ground_truth.split(),

View File

@ -651,7 +651,7 @@ def compute_loss(
info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.return_sym_delay:
info["sym_delay"] = sym_delay.detach().cpu().item()

View File

@ -71,6 +71,7 @@ Usage:
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@ -105,6 +106,7 @@ from icefall.utils import (
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -350,7 +352,7 @@ def decode_one_batch(
x_lens=feature_lens,
chunk_size=params.right_chunk_size,
left_context=params.left_context,
simulate_streaming=True
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(

View File

@ -56,7 +56,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
"""
import argparse
import copy
import logging