Fix style

pkufool 2022-05-29 07:48:52 +08:00
parent aecaecfb17
commit 605838da55
9 changed files with 33 additions and 31 deletions

View File

@@ -380,7 +380,7 @@ def decode_one_batch(
             x_lens=feature_lens,
             chunk_size=params.right_chunk_size,
             left_context=params.left_context,
-            simulate_streaming=True
+            simulate_streaming=True,
         )
     else:
         encoder_out, encoder_out_lens = model.encoder(
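The only substantive edit in this hunk is the trailing comma after the final keyword argument, which black enforces whenever it keeps a call expanded across multiple lines. A minimal sketch of the behavior (the function name and arguments below are hypothetical stand-ins, not this file's exact code):

    def streaming_forward(x_lens, simulate_streaming):
        # Hypothetical stand-in for the encoder call in the hunk above.
        return x_lens, simulate_streaming

    # black keeps a multi-line call exploded and adds a trailing comma
    # after the last argument (the "magic trailing comma"), so appending
    # an argument later only touches one line of a future diff:
    out = streaming_forward(
        x_lens=[10, 12],
        simulate_streaming=True,
    )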

View File

@@ -106,7 +106,7 @@ class DecodeStream(object):
             self.features.size(0) - self.num_processed_frames, chunk_size + 3
         )
         ret_features = self.features[
-            self.num_processed_frames : self.num_processed_frames
+            self.num_processed_frames : self.num_processed_frames  # noqa
             + ret_chunk_size,
             :,
         ]
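The # noqa added here (and in the greedy_search hunks below) silences flake8 on lines that black formats with spaces around a slice colon: black writes slices whose bounds are expressions as "a[x : y]", while flake8's default E203 check flags the whitespace before the colon. A small self-contained sketch of the conflict, assuming default flake8 settings:

    features = list(range(100))
    start, chunk_size = 10, 16

    # black formats slices whose bounds are expressions with spaces around
    # the colon; flake8's default E203 check flags that whitespace, so the
    # line needs "# noqa" (or a project-wide "extend-ignore = E203").
    chunk = features[start : start + chunk_size]  # noqa
    print(len(chunk))  # 16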

View File

@@ -157,9 +157,9 @@ class Transducer(nn.Module):
         else:
             offset = (boundary[:, 3] - 1) / 2
             total_syms = torch.sum(boundary[:, 2])
-            offset = torch.arange(
-                T0, device=px_grad.device
-            ).reshape(1, 1, T0) - offset.reshape(B, 1, 1)
+            offset = torch.arange(T0, device=px_grad.device).reshape(
+                1, 1, T0
+            ) - offset.reshape(B, 1, 1)
             sym_delay = px_grad * offset
             sym_delay = torch.sum(sym_delay) / total_syms
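For context on what this block computes (the reformatting does not change it): px_grad from k2's pruned RNN-T loss behaves roughly like a per-symbol, per-frame occupation probability, and the code turns it into an average symbol emission delay relative to the midpoint of each utterance. A standalone sketch with made-up sizes and random stand-ins for the real tensors (B sequences, S symbols, T0 frames; each boundary row is assumed to be [0, 0, num_symbols, num_frames]):

    import torch

    B, S, T0 = 2, 5, 10  # batch size, max symbols, frames (made-up sizes)

    # Stand-ins: px_grad[b, s, t] acts roughly as the probability that
    # symbol s of sequence b is emitted at frame t.
    px_grad = torch.softmax(torch.randn(B, S, T0), dim=-1)
    boundary = torch.tensor([[0, 0, 5, 10], [0, 0, 4, 8]])

    offset = (boundary[:, 3] - 1) / 2        # per-sequence midpoint frame
    total_syms = torch.sum(boundary[:, 2])   # total symbols in the batch

    # Signed distance of every frame index from the midpoint: (B, 1, T0).
    offset = torch.arange(T0, device=px_grad.device).reshape(
        1, 1, T0
    ) - offset.reshape(B, 1, 1)

    # Expected emission frame relative to the midpoint, averaged over all
    # symbols; positive values mean symbols are emitted late (delayed).
    sym_delay = torch.sum(px_grad * offset) / total_syms
    print(sym_delay)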

View File

@@ -28,35 +28,35 @@ Usage:
 import argparse
 import logging
 import math
-from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

-import numpy as np
 import k2
+import numpy as np
 import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from decode_stream import DecodeStream
+from kaldifeat import Fbank, FbankOptions
+from lhotse import CutSet
+from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model
+
 from icefall.checkpoint import (
     average_checkpoints,
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.decode import one_best_decoding
 from icefall.utils import (
     AttributeDict,
+    get_texts,
     setup_logger,
     store_transcripts,
     str2bool,
     write_error_stats,
 )
-from icefall.decode import Nbest, one_best_decoding
-from icefall.utils import get_texts
-from kaldifeat import FbankOptions, Fbank
-from lhotse import CutSet
-from train import get_params, get_transducer_model
-from torch.nn.utils.rnn import pad_sequence

 LOG_EPS = math.log(1e-10)
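This import hunk is a mechanical isort pass: third-party imports are alphabetized (k2 before numpy), the module-level imports that had drifted to the bottom (kaldifeat, lhotse, train, pad_sequence) move up into their group, names inside a single import are sorted (Fbank before FbankOptions), and imports flake8 flags as unused (defaultdict, Nbest) are dropped. A sketch of an isort configuration that produces this grouping (an assumption for illustration; the repository may configure it differently):

    # [tool.isort] in pyproject.toml (or the equivalent setup.cfg section):
    #
    #   [tool.isort]
    #   profile = "black"                # wrapping style compatible with black
    #   known_first_party = ["icefall"]  # gives icefall its own import group
    #
    # With this, isort orders: standard library, third-party (k2, numpy, ...),
    # then first-party (icefall), each group alphabetized and separated by a
    # blank line.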
@@ -239,7 +239,7 @@ def greedy_search(
     for t in range(T):
         # current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
-        current_encoder_out = encoder_out[:, t : t + 1, :]
+        current_encoder_out = encoder_out[:, t : t + 1, :]  # noqa

         logits = model.joiner(
             current_encoder_out.unsqueeze(2),
@@ -494,7 +494,7 @@ def decode_dataset(
         for i in sorted(finished_streams, reverse=True):
             hyp = decode_streams[i].hyp
             if params.decoding_method == "greedy_search":
-                hyp = hyp[params.context_size :]
+                hyp = hyp[params.context_size :]  # noqa
             decode_results.append(
                 (
                     decode_streams[i].ground_truth.split(),
@@ -512,7 +512,7 @@ def decode_dataset(
         for i in sorted(finished_streams, reverse=True):
             hyp = decode_streams[i].hyp
             if params.decoding_method == "greedy_search":
-                hyp = hyp[params.context_size :]
+                hyp = hyp[params.context_size :]  # noqa
             decode_results.append(
                 (
                     decode_streams[i].ground_truth.split(),

View File

@@ -104,6 +104,7 @@ from icefall.utils import (
 LOG_EPS = math.log(1e-10)

+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -339,7 +340,7 @@ def decode_one_batch(
             x_lens=feature_lens,
             chunk_size=params.right_chunk_size,
             left_context=params.left_context,
-            simulate_streaming=True
+            simulate_streaming=True,
         )
     else:
         encoder_out, encoder_out_lens = model.encoder(

View File

@@ -28,35 +28,35 @@ Usage:
 import argparse
 import logging
 import math
-from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

-import numpy as np
 import k2
+import numpy as np
 import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from decode_stream import DecodeStream
+from kaldifeat import Fbank, FbankOptions
+from lhotse import CutSet
+from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model
+
 from icefall.checkpoint import (
     average_checkpoints,
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.decode import one_best_decoding
 from icefall.utils import (
     AttributeDict,
+    get_texts,
     setup_logger,
     store_transcripts,
     str2bool,
     write_error_stats,
 )
-from icefall.decode import Nbest, one_best_decoding
-from icefall.utils import get_texts
-from kaldifeat import FbankOptions, Fbank
-from lhotse import CutSet
-from train import get_params, get_transducer_model
-from torch.nn.utils.rnn import pad_sequence

 LOG_EPS = math.log(1e-10)
@@ -241,7 +241,7 @@ def greedy_search(
     for t in range(T):
         # current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
-        current_encoder_out = encoder_out[:, t : t + 1, :]
+        current_encoder_out = encoder_out[:, t : t + 1, :]  # noqa

         logits = model.joiner(
             current_encoder_out.unsqueeze(2),
@@ -501,7 +501,7 @@ def decode_dataset(
         for i in sorted(finished_streams, reverse=True):
             hyp = decode_streams[i].hyp
             if params.decoding_method == "greedy_search":
-                hyp = hyp[params.context_size :]
+                hyp = hyp[params.context_size :]  # noqa
             decode_results.append(
                 (
                     decode_streams[i].ground_truth.split(),
@@ -519,7 +519,7 @@ def decode_dataset(
         for i in sorted(finished_streams, reverse=True):
             hyp = decode_streams[i].hyp
             if params.decoding_method == "greedy_search":
-                hyp = hyp[params.context_size :]
+                hyp = hyp[params.context_size :]  # noqa
             decode_results.append(
                 (
                     decode_streams[i].ground_truth.split(),

View File

@@ -651,7 +651,7 @@ def compute_loss(
         info["loss"] = loss.detach().cpu().item()
         info["simple_loss"] = simple_loss.detach().cpu().item()
         info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
         if params.return_sym_delay:
             info["sym_delay"] = sym_delay.detach().cpu().item()

View File

@@ -71,6 +71,7 @@ Usage:
 import argparse
 import logging
+import math
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -105,6 +106,7 @@ from icefall.utils import (
 LOG_EPS = math.log(1e-10)

+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -350,7 +352,7 @@ def decode_one_batch(
             x_lens=feature_lens,
             chunk_size=params.right_chunk_size,
             left_context=params.left_context,
-            simulate_streaming=True
+            simulate_streaming=True,
         )
     else:
         encoder_out, encoder_out_lens = model.encoder(

View File

@@ -56,7 +56,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 """

-
 import argparse
 import copy
 import logging