mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-11 10:04:21 +00:00
Fix style
This commit is contained in:
parent
aecaecfb17
commit
605838da55
@ -380,7 +380,7 @@ def decode_one_batch(
|
|||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
chunk_size=params.right_chunk_size,
|
chunk_size=params.right_chunk_size,
|
||||||
left_context=params.left_context,
|
left_context=params.left_context,
|
||||||
simulate_streaming=True
|
simulate_streaming=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
encoder_out, encoder_out_lens = model.encoder(
|
encoder_out, encoder_out_lens = model.encoder(
|
||||||
|
@ -106,7 +106,7 @@ class DecodeStream(object):
|
|||||||
self.features.size(0) - self.num_processed_frames, chunk_size + 3
|
self.features.size(0) - self.num_processed_frames, chunk_size + 3
|
||||||
)
|
)
|
||||||
ret_features = self.features[
|
ret_features = self.features[
|
||||||
self.num_processed_frames : self.num_processed_frames
|
self.num_processed_frames : self.num_processed_frames # noqa
|
||||||
+ ret_chunk_size,
|
+ ret_chunk_size,
|
||||||
:,
|
:,
|
||||||
]
|
]
|
||||||
|
@ -157,9 +157,9 @@ class Transducer(nn.Module):
|
|||||||
else:
|
else:
|
||||||
offset = (boundary[:, 3] - 1) / 2
|
offset = (boundary[:, 3] - 1) / 2
|
||||||
total_syms = torch.sum(boundary[:, 2])
|
total_syms = torch.sum(boundary[:, 2])
|
||||||
offset = torch.arange(
|
offset = torch.arange(T0, device=px_grad.device).reshape(
|
||||||
T0, device=px_grad.device
|
1, 1, T0
|
||||||
).reshape(1, 1, T0) - offset.reshape(B, 1, 1)
|
) - offset.reshape(B, 1, 1)
|
||||||
sym_delay = px_grad * offset
|
sym_delay = px_grad * offset
|
||||||
sym_delay = torch.sum(sym_delay) / total_syms
|
sym_delay = torch.sum(sym_delay) / total_syms
|
||||||
|
|
||||||
|
@ -28,35 +28,35 @@ Usage:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from collections import defaultdict
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import k2
|
import k2
|
||||||
|
import numpy as np
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from decode_stream import DecodeStream
|
from decode_stream import DecodeStream
|
||||||
|
from kaldifeat import Fbank, FbankOptions
|
||||||
|
from lhotse import CutSet
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
|
from icefall.decode import one_best_decoding
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
|
get_texts,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts,
|
store_transcripts,
|
||||||
str2bool,
|
str2bool,
|
||||||
write_error_stats,
|
write_error_stats,
|
||||||
)
|
)
|
||||||
from icefall.decode import Nbest, one_best_decoding
|
|
||||||
from icefall.utils import get_texts
|
|
||||||
from kaldifeat import FbankOptions, Fbank
|
|
||||||
from lhotse import CutSet
|
|
||||||
from train import get_params, get_transducer_model
|
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
|
||||||
|
|
||||||
LOG_EPS = math.log(1e-10)
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
@ -239,7 +239,7 @@ def greedy_search(
|
|||||||
|
|
||||||
for t in range(T):
|
for t in range(T):
|
||||||
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
|
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
|
||||||
current_encoder_out = encoder_out[:, t : t + 1, :]
|
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
|
||||||
|
|
||||||
logits = model.joiner(
|
logits = model.joiner(
|
||||||
current_encoder_out.unsqueeze(2),
|
current_encoder_out.unsqueeze(2),
|
||||||
@ -494,7 +494,7 @@ def decode_dataset(
|
|||||||
for i in sorted(finished_streams, reverse=True):
|
for i in sorted(finished_streams, reverse=True):
|
||||||
hyp = decode_streams[i].hyp
|
hyp = decode_streams[i].hyp
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
hyp = hyp[params.context_size :]
|
hyp = hyp[params.context_size :] # noqa
|
||||||
decode_results.append(
|
decode_results.append(
|
||||||
(
|
(
|
||||||
decode_streams[i].ground_truth.split(),
|
decode_streams[i].ground_truth.split(),
|
||||||
@ -512,7 +512,7 @@ def decode_dataset(
|
|||||||
for i in sorted(finished_streams, reverse=True):
|
for i in sorted(finished_streams, reverse=True):
|
||||||
hyp = decode_streams[i].hyp
|
hyp = decode_streams[i].hyp
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
hyp = hyp[params.context_size :]
|
hyp = hyp[params.context_size :] # noqa
|
||||||
decode_results.append(
|
decode_results.append(
|
||||||
(
|
(
|
||||||
decode_streams[i].ground_truth.split(),
|
decode_streams[i].ground_truth.split(),
|
||||||
|
@ -104,6 +104,7 @@ from icefall.utils import (
|
|||||||
|
|
||||||
LOG_EPS = math.log(1e-10)
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
@ -339,7 +340,7 @@ def decode_one_batch(
|
|||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
chunk_size=params.right_chunk_size,
|
chunk_size=params.right_chunk_size,
|
||||||
left_context=params.left_context,
|
left_context=params.left_context,
|
||||||
simulate_streaming=True
|
simulate_streaming=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
encoder_out, encoder_out_lens = model.encoder(
|
encoder_out, encoder_out_lens = model.encoder(
|
||||||
|
@ -28,35 +28,35 @@ Usage:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from collections import defaultdict
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import k2
|
import k2
|
||||||
|
import numpy as np
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from decode_stream import DecodeStream
|
from decode_stream import DecodeStream
|
||||||
|
from kaldifeat import Fbank, FbankOptions
|
||||||
|
from lhotse import CutSet
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
|
from icefall.decode import one_best_decoding
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
|
get_texts,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts,
|
store_transcripts,
|
||||||
str2bool,
|
str2bool,
|
||||||
write_error_stats,
|
write_error_stats,
|
||||||
)
|
)
|
||||||
from icefall.decode import Nbest, one_best_decoding
|
|
||||||
from icefall.utils import get_texts
|
|
||||||
from kaldifeat import FbankOptions, Fbank
|
|
||||||
from lhotse import CutSet
|
|
||||||
from train import get_params, get_transducer_model
|
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
|
||||||
|
|
||||||
LOG_EPS = math.log(1e-10)
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
@ -241,7 +241,7 @@ def greedy_search(
|
|||||||
|
|
||||||
for t in range(T):
|
for t in range(T):
|
||||||
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
|
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
|
||||||
current_encoder_out = encoder_out[:, t : t + 1, :]
|
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
|
||||||
|
|
||||||
logits = model.joiner(
|
logits = model.joiner(
|
||||||
current_encoder_out.unsqueeze(2),
|
current_encoder_out.unsqueeze(2),
|
||||||
@ -501,7 +501,7 @@ def decode_dataset(
|
|||||||
for i in sorted(finished_streams, reverse=True):
|
for i in sorted(finished_streams, reverse=True):
|
||||||
hyp = decode_streams[i].hyp
|
hyp = decode_streams[i].hyp
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
hyp = hyp[params.context_size :]
|
hyp = hyp[params.context_size :] # noqa
|
||||||
decode_results.append(
|
decode_results.append(
|
||||||
(
|
(
|
||||||
decode_streams[i].ground_truth.split(),
|
decode_streams[i].ground_truth.split(),
|
||||||
@ -519,7 +519,7 @@ def decode_dataset(
|
|||||||
for i in sorted(finished_streams, reverse=True):
|
for i in sorted(finished_streams, reverse=True):
|
||||||
hyp = decode_streams[i].hyp
|
hyp = decode_streams[i].hyp
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
hyp = hyp[params.context_size :]
|
hyp = hyp[params.context_size :] # noqa
|
||||||
decode_results.append(
|
decode_results.append(
|
||||||
(
|
(
|
||||||
decode_streams[i].ground_truth.split(),
|
decode_streams[i].ground_truth.split(),
|
||||||
|
@ -651,7 +651,7 @@ def compute_loss(
|
|||||||
info["loss"] = loss.detach().cpu().item()
|
info["loss"] = loss.detach().cpu().item()
|
||||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||||
|
|
||||||
if params.return_sym_delay:
|
if params.return_sym_delay:
|
||||||
info["sym_delay"] = sym_delay.detach().cpu().item()
|
info["sym_delay"] = sym_delay.detach().cpu().item()
|
||||||
|
|
||||||
|
@ -71,6 +71,7 @@ Usage:
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
@ -105,6 +106,7 @@ from icefall.utils import (
|
|||||||
|
|
||||||
LOG_EPS = math.log(1e-10)
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
@ -350,7 +352,7 @@ def decode_one_batch(
|
|||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
chunk_size=params.right_chunk_size,
|
chunk_size=params.right_chunk_size,
|
||||||
left_context=params.left_context,
|
left_context=params.left_context,
|
||||||
simulate_streaming=True
|
simulate_streaming=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
encoder_out, encoder_out_lens = model.encoder(
|
encoder_out, encoder_out_lens = model.encoder(
|
||||||
|
@ -56,7 +56,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
|
Loading…
x
Reference in New Issue
Block a user