diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index cdb8b5d1b..b08542950 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -380,7 +380,7 @@ def decode_one_batch( x_lens=feature_lens, chunk_size=params.right_chunk_size, left_context=params.left_context, - simulate_streaming=True + simulate_streaming=True, ) else: encoder_out, encoder_out_lens = model.encoder( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py index 5afe4f859..fd1be9cdf 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py @@ -106,7 +106,7 @@ class DecodeStream(object): self.features.size(0) - self.num_processed_frames, chunk_size + 3 ) ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames + self.num_processed_frames : self.num_processed_frames # noqa + ret_chunk_size, :, ] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py index 8b50ac657..1c83adc44 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py @@ -157,9 +157,9 @@ class Transducer(nn.Module): else: offset = (boundary[:, 3] - 1) / 2 total_syms = torch.sum(boundary[:, 2]) - offset = torch.arange( - T0, device=px_grad.device - ).reshape(1, 1, T0) - offset.reshape(B, 1, 1) + offset = torch.arange(T0, device=px_grad.device).reshape( + 1, 1, T0 + ) - offset.reshape(B, 1, 1) sym_delay = px_grad * offset sym_delay = torch.sum(sym_delay) / total_syms diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index f4a68e958..62f602abc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -28,35 +28,35 @@ Usage: import argparse import logging import math -from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple -import numpy as np import k2 +import numpy as np import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from torch.nn.utils.rnn import pad_sequence +from train import get_params, get_transducer_model + from icefall.checkpoint import ( average_checkpoints, find_checkpoints, load_checkpoint, ) +from icefall.decode import one_best_decoding from icefall.utils import ( AttributeDict, + get_texts, setup_logger, store_transcripts, str2bool, write_error_stats, ) -from icefall.decode import Nbest, one_best_decoding -from icefall.utils import get_texts -from kaldifeat import FbankOptions, Fbank -from lhotse import CutSet -from train import get_params, get_transducer_model -from torch.nn.utils.rnn import pad_sequence LOG_EPS = math.log(1e-10) @@ -239,7 +239,7 @@ def greedy_search( for t in range(T): # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) - current_encoder_out = encoder_out[:, t : t + 1, :] + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa logits = model.joiner( current_encoder_out.unsqueeze(2), @@ -494,7 +494,7 @@ def decode_dataset( for i in sorted(finished_streams, reverse=True): hyp = decode_streams[i].hyp if params.decoding_method == "greedy_search": - hyp = hyp[params.context_size :] + hyp = hyp[params.context_size :] # noqa decode_results.append( ( decode_streams[i].ground_truth.split(), @@ -512,7 +512,7 @@ def decode_dataset( for i in sorted(finished_streams, reverse=True): hyp = decode_streams[i].hyp if params.decoding_method == "greedy_search": - hyp = hyp[params.context_size :] + hyp = hyp[params.context_size :] # noqa decode_results.append( ( decode_streams[i].ground_truth.split(), diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 37edf1ff9..e812fb534 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -104,6 +104,7 @@ from icefall.utils import ( LOG_EPS = math.log(1e-10) + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -339,7 +340,7 @@ def decode_one_batch( x_lens=feature_lens, chunk_size=params.right_chunk_size, left_context=params.left_context, - simulate_streaming=True + simulate_streaming=True, ) else: encoder_out, encoder_out_lens = model.encoder( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index c95b22c0e..ced5e2a3b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -28,35 +28,35 @@ Usage: import argparse import logging import math -from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple -import numpy as np import k2 +import numpy as np import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from torch.nn.utils.rnn import pad_sequence +from train import get_params, get_transducer_model + from icefall.checkpoint import ( average_checkpoints, find_checkpoints, load_checkpoint, ) +from icefall.decode import one_best_decoding from icefall.utils import ( AttributeDict, + get_texts, setup_logger, store_transcripts, str2bool, write_error_stats, ) -from icefall.decode import Nbest, one_best_decoding -from icefall.utils import get_texts -from kaldifeat import FbankOptions, Fbank -from lhotse import CutSet -from train import get_params, get_transducer_model -from torch.nn.utils.rnn import pad_sequence LOG_EPS = math.log(1e-10) @@ -241,7 +241,7 @@ def greedy_search( for t in range(T): # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) - current_encoder_out = encoder_out[:, t : t + 1, :] + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa logits = model.joiner( current_encoder_out.unsqueeze(2), @@ -501,7 +501,7 @@ def decode_dataset( for i in sorted(finished_streams, reverse=True): hyp = decode_streams[i].hyp if params.decoding_method == "greedy_search": - hyp = hyp[params.context_size :] + hyp = hyp[params.context_size :] # noqa decode_results.append( ( decode_streams[i].ground_truth.split(), @@ -519,7 +519,7 @@ def decode_dataset( for i in sorted(finished_streams, reverse=True): hyp = decode_streams[i].hyp if params.decoding_method == "greedy_search": - hyp = hyp[params.context_size :] + hyp = hyp[params.context_size :] # noqa decode_results.append( ( decode_streams[i].ground_truth.split(), diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 3bab3ba98..cee271706 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -651,7 +651,7 @@ def compute_loss( info["loss"] = loss.detach().cpu().item() info["simple_loss"] = simple_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item() - + if params.return_sym_delay: info["sym_delay"] = sym_delay.detach().cpu().item() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 5f5b01f59..db8a80f46 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -71,6 +71,7 @@ Usage: import argparse import logging +import math from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -105,6 +106,7 @@ from icefall.utils import ( LOG_EPS = math.log(1e-10) + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -350,7 +352,7 @@ def decode_one_batch( x_lens=feature_lens, chunk_size=params.right_chunk_size, left_context=params.left_context, - simulate_streaming=True + simulate_streaming=True, ) else: encoder_out, encoder_out_lens = model.encoder( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index c73e7777b..2a46d2c18 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -56,7 +56,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" """ - import argparse import copy import logging