mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)
Minor refinements for some stale but recently merged PRs (#1354)
* incorporate https://github.com/k2-fsa/icefall/pull/1269
* incorporate https://github.com/k2-fsa/icefall/pull/1301
* black formatted
* incorporate https://github.com/k2-fsa/icefall/pull/1162
* black formatted

This commit is contained in:
parent c970df512b
commit 23913f6afd
@@ -1128,7 +1128,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

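This `2**22` → `512` replacement recurs in the `run()` functions of several recipes below. Note that the trailing comment, "allow 4 megabytes per sub-module", described the old value (2**22 bytes = 4 MiB) and reads as stale after the change. For context, a minimal sketch of how these diagnostics are typically wired up in icefall training scripts; the stand-in model and the final `print_diagnostics()` call follow the usual pattern but are assumptions here:

    import torch
    import torch.nn as nn
    from icefall import diagnostics  # import path as used in icefall train scripts

    model = nn.Linear(16, 16)  # hypothetical stand-in for the real ASR model
    opts = diagnostics.TensorDiagnosticOptions(512)
    diagnostic = diagnostics.attach_diagnostics(model, opts)
    model(torch.randn(8, 16)).sum().backward()  # a few training batches in practice
    diagnostic.print_diagnostics()  # assumption: usual icefall reporting call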
@@ -233,7 +233,7 @@ class GigaSpeechAsrDataModule:
             logging.info("About to get Musan cuts")
             cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
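This tracks lhotse's rename of `CutMix`'s keyword argument from `prob` to `p` (the per-cut probability of mixing in MUSAN noise); the same fix is applied to the LibriSpeech data module below. Standalone, the updated transform can be built like this (the manifest path is illustrative):

    from lhotse import load_manifest
    from lhotse.dataset import CutMix

    cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz")  # illustrative path
    transform = CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)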
@@ -1164,7 +1164,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer):

         yield tuples  # <-- calling code will do the actual optimization here!

-        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
+        for (stacked_params, _state, _names), batch in zip(tuples, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])
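Several of the optimizer hunks here and below are pure black reformatting: redundant parentheses around the tuple-unpacking target of a `for` loop are dropped, and stray blank lines after signatures are removed. The two loop spellings are semantically identical:

    pairs = [((1, 2, 3), "a"), ((4, 5, 6), "b")]
    for ((x, _y, _z), tag) in pairs:  # old spelling
        assert tag in ("a", "b")
    for (x, _y, _z), tag in pairs:  # black-formatted spelling
        assert x in (1, 4)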
@@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer):
         size_update_period=4,
         clipping_update_period=100,
     ):
-
         defaults = dict(
             lr=lr,
             clipping_scale=clipping_scale,
@@ -327,9 +326,7 @@ class ScaledAdam(BatchedOptimizer):
         batch = True

         for group, group_params_names in zip(self.param_groups, self.parameters_names):
-
             with self.batched_params(group["params"], group_params_names) as batches:
-
                 # batches is list of pairs (stacked_param, state). stacked_param is like
                 # a regular parameter, and will have a .grad, but the 1st dim corresponds to
                 # a stacking dim, it is not a real dim.
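The context comment describes ScaledAdam's batching trick: parameters with identical shape are stacked along a new leading dim so one vectorized kernel can update all of them, and the results are then scattered back. A toy illustration of that round trip (not icefall code, just the idea):

    import torch

    # Three same-shaped parameters stacked so a single op updates all of them.
    params = [torch.randn(4, 4) for _ in range(3)]
    stacked = torch.stack(params)  # shape (3, 4, 4); dim 0 is the stacking dim
    stacked -= 0.1  # one vectorized "update" in place of per-parameter loops
    for p, updated in zip(params, stacked):
        p.copy_(updated)  # write results back, as BatchedOptimizer does above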
@@ -429,7 +426,7 @@ class ScaledAdam(BatchedOptimizer):
             clipping_update_period = group["clipping_update_period"]

             tot_sumsq = torch.tensor(0.0, device=first_p.device)
-            for (p, state, param_names) in tuples:
+            for p, state, param_names in tuples:
                 grad = p.grad
                 if grad.is_sparse:
                     raise RuntimeError(
@@ -514,7 +511,7 @@ class ScaledAdam(BatchedOptimizer):
        from tuples, we still pass it to save some time.
        """
        all_sumsq_orig = {}
-        for (p, state, batch_param_names) in tuples:
+        for p, state, batch_param_names in tuples:
            # p is a stacked batch parameters.
            batch_grad = p.grad
            if p.numel() == p.shape[0]:  # a batch of scalars
@@ -530,7 +527,6 @@ class ScaledAdam(BatchedOptimizer):
            for name, sumsq_orig, rms, grad in zip(
                batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
            ):
-
                proportion_orig = sumsq_orig / tot_sumsq
                all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
@@ -1106,7 +1102,7 @@ def _test_scaled_adam(hidden_dim: int):

        # if epoch == 130:
        #     opts = diagnostics.TensorDiagnosticOptions(
-        #         2 ** 22
+        #         512
        #     ) # allow 4 megabytes per sub-module
        #     diagnostic = diagnostics.attach_diagnostics(m, opts)

@@ -1194,7 +1194,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -1565,7 +1565,7 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         args.max_duration = 100
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -31,7 +31,7 @@ from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
-    SingleCutSampler,
+    SimpleCutSampler,
     SpecAugment,
 )
 from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
@@ -225,7 +225,7 @@ class LibriSpeechAsrDataModule:
             logging.info("About to get Musan cuts")
             cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -307,8 +307,8 @@ class LibriSpeechAsrDataModule:
                drop_last=self.args.drop_last,
            )
        else:
-            logging.info("Using SingleCutSampler.")
-            train_sampler = SingleCutSampler(
+            logging.info("Using SimpleCutSampler.")
+            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
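lhotse renamed `SingleCutSampler` to `SimpleCutSampler`; the import hunk above and this call site change together. A minimal sketch of constructing the renamed sampler (the manifest path is illustrative):

    from lhotse import load_manifest_lazy
    from lhotse.dataset import SimpleCutSampler

    cuts_train = load_manifest_lazy("data/fbank/cuts_train.jsonl.gz")  # illustrative path
    sampler = SimpleCutSampler(cuts_train, max_duration=200.0, shuffle=True)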
@@ -22,10 +22,11 @@
 import argparse
 import logging
 import math
+import pprint
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
-import pprint

 import k2
 import sentencepiece as spm
 import torch
@@ -253,7 +253,6 @@ class CausalSqueezeExcite1d(nn.Module):
        return y

    def forward(self, x: Tensor) -> Tensor:
-
        assert len(x.shape) == 3, "Input is not a 3D tensor!"
        y = self.exponential_moving_avg(x)
        y = y.permute(0, 2, 1)  # make channel last for squeeze op
@@ -76,17 +76,17 @@ import argparse
 import logging
 from pathlib import Path

 import sentencepiece as spm
+import k2
 import torch
+from train import add_model_arguments, get_params, get_transducer_model

 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.lexicon import UniqLexicon
-from icefall.utils import str2bool
-from train import add_model_arguments, get_params, get_transducer_model
+from icefall.utils import num_tokens, str2bool


 def get_parser():
@@ -143,13 +143,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--lang-dir",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500",
-        help="""The lang dir
-        It contains language related input files such as
-        "lexicon.txt"
-        """,
+        default="data/lang_bpe_500/tokens.txt",
+        help="Path to the tokens.txt",
     )

     parser.add_argument(
@@ -189,17 +186,9 @@ def main():

     logging.info(f"device: {device}")

-    if "lang_bpe" in str(params.lang_dir):
-        sp = spm.SentencePieceProcessor()
-        sp.load(params.lang_dir + "/bpe.model")
-        # <blk> is defined in local/train_bpe_model.py
-        params.blank_id = sp.piece_to_id("<blk>")
-        params.vocab_size = sp.get_piece_size()
-    else:
-        assert "lang_phone" in str(params.lang_dir)
-        phone_lexicon = UniqLexicon(params.lang_dir)
-        params.blank_id = 0
-        params.vocab_size = max(phone_lexicon.tokens) + 1
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(params)

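This is the heart of the `--lang-dir` → `--tokens` migration: instead of branching on the lang dir type and loading a BPE model or a phone lexicon just to recover the blank id and vocabulary size, the script reads both directly from tokens.txt. The same pattern is applied to the CTC export script further down. Standalone, the new logic looks roughly like this (the tokens.txt path is illustrative):

    import k2
    from icefall.utils import num_tokens

    token_table = k2.SymbolTable.from_file("data/lang_bpe_500/tokens.txt")  # illustrative
    blank_id = token_table["<blk>"]
    vocab_size = num_tokens(token_table) + 1  # +1 for <blk>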
@@ -89,7 +89,6 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:


 def add_model_arguments(parser: argparse.ArgumentParser):
-
     parser.add_argument(
         "--encoder-dim",
         type=int,
@@ -405,7 +404,6 @@ def get_params() -> AttributeDict:


 def get_encoder_model(params: AttributeDict) -> nn.Module:
-
     encoder = Conv1dNet(
         output_dim=params.encoder_dim,
         input_dim=params.feature_dim,
@@ -1043,7 +1041,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -23,6 +23,7 @@ import argparse
 import logging
 from pathlib import Path

+import k2
 import torch
 from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_ctc_model, get_params
@@ -33,8 +34,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool


 def get_parser():
@@ -90,11 +90,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--lang-dir",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500",
-        help="""It contains language related input files such as "lexicon.txt"
-        """,
+        default="data/lang_bpe_500/tokens.txt",
+        help="Path to the tokens.txt",
     )

     parser.add_argument(
@@ -113,17 +112,15 @@ def get_parser():
 def main():
     args = get_parser().parse_args()
     args.exp_dir = Path(args.exp_dir)
-    args.lang_dir = Path(args.lang_dir)

     params = get_params()
     params.update(vars(args))

-    logging.info(params)
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

-    lexicon = Lexicon(params.lang_dir)
-    max_token_id = max(lexicon.tokens)
-    num_classes = max_token_id + 1  # +1 for the blank
-    params.vocab_size = num_classes
+    logging.info(params)

     device = torch.device("cpu")
     if torch.cuda.is_available():
@@ -947,7 +947,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -244,16 +244,14 @@ class TensorDiagnostic(object):

            if stats_type == "eigs":
                try:
-                    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
+                    if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
                        eigs, _ = torch.linalg.eigh(stats)
                    else:
                        eigs, _ = torch.symeig(stats)
                    stats = eigs.abs().sqrt()
                except:  # noqa
-                    print(
-                        "Error getting eigenvalues, trying another method."
-                    )
-                    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'):
+                    print("Error getting eigenvalues, trying another method.")
+                    if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"):
                        eigs, _ = torch.linalg.eig(stats)
                        eigs = eigs.abs()
                    else:
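Besides normalizing quotes to black's double-quote style, these hunks keep the version guard intact: `torch.linalg.eigh` is the current API for symmetric eigendecomposition, while `torch.symeig` exists only in older PyTorch releases (it was deprecated and later removed). A self-contained sketch of the same fallback:

    import torch

    stats = torch.randn(8, 8)
    stats = stats @ stats.T  # symmetric, like the accumulated covariance stats
    if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
        eigs, _ = torch.linalg.eigh(stats)  # current API
    else:
        eigs, _ = torch.symeig(stats)  # pre-torch.linalg fallback only
    print(eigs.abs().sqrt())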
@@ -579,10 +577,15 @@ def attach_diagnostics(
                )
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
-                    if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
-                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
-                                                                             class_name=get_class_name(_module))
-
+                    if isinstance(o, Tensor) and o.dtype in (
+                        torch.float32,
+                        torch.float16,
+                        torch.float64,
+                    ):
+                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
+                            o, class_name=get_class_name(_module)
+                        )

        def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
            if isinstance(_output, tuple) and len(_output) == 1:
                _output = _output[0]
@@ -596,9 +599,15 @@ def attach_diagnostics(
                )
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
-                    if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
-                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
-                                                                           class_name=get_class_name(_module))
+                    if isinstance(o, Tensor) and o.dtype in (
+                        torch.float32,
+                        torch.float16,
+                        torch.float64,
+                    ):
+                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
+                            o, class_name=get_class_name(_module)
+                        )

        module.register_forward_hook(forward_hook)
        module.register_backward_hook(backward_hook)
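attach_diagnostics is built on standard PyTorch module hooks: a forward hook observes every submodule's outputs and a backward hook observes the gradients flowing back, with each floating-point tensor accumulated under a per-module key. A self-contained sketch of the same pattern, printing instead of accumulating:

    import torch
    import torch.nn as nn

    def forward_hook(module, inputs, output):
        # Mirror the dtype filter used above: only inspect float tensors.
        if isinstance(output, torch.Tensor) and output.dtype in (
            torch.float32,
            torch.float16,
            torch.float64,
        ):
            print(type(module).__name__, "output", tuple(output.shape))

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    for module in model.modules():
        module.register_forward_hook(forward_hook)
    model(torch.randn(3, 4))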