Minor refinements for some stale but recently merged PRs (#1354)

* incorporate https://github.com/k2-fsa/icefall/pull/1269

* incorporate https://github.com/k2-fsa/icefall/pull/1301

* black formatted

* incorporate https://github.com/k2-fsa/icefall/pull/1162

* black formatted
This commit is contained in:
zr_jin 2023-10-31 10:28:20 +08:00 committed by GitHub
parent c970df512b
commit 23913f6afd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 57 additions and 68 deletions

View File

@ -1128,7 +1128,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -233,7 +233,7 @@ class GigaSpeechAsrDataModule:
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")

View File

@ -1164,7 +1164,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer):
yield tuples # <-- calling code will do the actual optimization here!
for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
for (stacked_params, _state, _names), batch in zip(tuples, batches):
for i, p in enumerate(batch): # batch is list of Parameter
p.copy_(stacked_params[i])
@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer):
size_update_period=4,
clipping_update_period=100,
):
defaults = dict(
lr=lr,
clipping_scale=clipping_scale,
@ -327,9 +326,7 @@ class ScaledAdam(BatchedOptimizer):
batch = True
for group, group_params_names in zip(self.param_groups, self.parameters_names):
with self.batched_params(group["params"], group_params_names) as batches:
# batches is list of pairs (stacked_param, state). stacked_param is like
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
# a stacking dim, it is not a real dim.
@ -429,7 +426,7 @@ class ScaledAdam(BatchedOptimizer):
clipping_update_period = group["clipping_update_period"]
tot_sumsq = torch.tensor(0.0, device=first_p.device)
for (p, state, param_names) in tuples:
for p, state, param_names in tuples:
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
@ -514,7 +511,7 @@ class ScaledAdam(BatchedOptimizer):
from tuples, we still pass it to save some time.
"""
all_sumsq_orig = {}
for (p, state, batch_param_names) in tuples:
for p, state, batch_param_names in tuples:
# p is a stacked batch parameters.
batch_grad = p.grad
if p.numel() == p.shape[0]: # a batch of scalars
@ -530,7 +527,6 @@ class ScaledAdam(BatchedOptimizer):
for name, sumsq_orig, rms, grad in zip(
batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
):
proportion_orig = sumsq_orig / tot_sumsq
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
@ -1106,7 +1102,7 @@ def _test_scaled_adam(hidden_dim: int):
# if epoch == 130:
# opts = diagnostics.TensorDiagnosticOptions(
# 2 ** 22
# 512
# ) # allow 4 megabytes per sub-module
# diagnostic = diagnostics.attach_diagnostics(m, opts)

View File

@ -1194,7 +1194,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -1565,7 +1565,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
args.max_duration = 100
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
@ -225,7 +225,7 @@ class LibriSpeechAsrDataModule:
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
@ -307,8 +307,8 @@ class LibriSpeechAsrDataModule:
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,

View File

@ -22,10 +22,11 @@
import argparse
import logging
import math
import pprint
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import pprint
import k2
import sentencepiece as spm
import torch

View File

@ -253,7 +253,6 @@ class CausalSqueezeExcite1d(nn.Module):
return y
def forward(self, x: Tensor) -> Tensor:
assert len(x.shape) == 3, "Input is not a 3D tensor!"
y = self.exponential_moving_avg(x)
y = y.permute(0, 2, 1) # make channel last for squeeze op

View File

@ -76,17 +76,17 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import UniqLexicon
from icefall.utils import str2bool
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -143,13 +143,10 @@ def get_parser():
)
parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_bpe_500",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
@ -189,17 +186,9 @@ def main():
logging.info(f"device: {device}")
if "lang_bpe" in str(params.lang_dir):
sp = spm.SentencePieceProcessor()
sp.load(params.lang_dir + "/bpe.model")
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
else:
assert "lang_phone" in str(params.lang_dir)
phone_lexicon = UniqLexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(phone_lexicon.tokens) + 1
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -89,7 +89,6 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--encoder-dim",
type=int,
@ -405,7 +404,6 @@ def get_params() -> AttributeDict:
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conv1dNet(
output_dim=params.encoder_dim,
input_dim=params.feature_dim,
@ -1043,7 +1041,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -23,6 +23,7 @@ import argparse
import logging
from pathlib import Path
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_ctc_model, get_params
@ -33,8 +34,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -90,11 +90,10 @@ def get_parser():
)
parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_bpe_500",
help="""It contains language related input files such as "lexicon.txt"
""",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
@ -113,17 +112,15 @@ def get_parser():
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
logging.info(params)
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
params.vocab_size = num_classes
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():

View File

@ -947,7 +947,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -244,16 +244,14 @@ class TensorDiagnostic(object):
if stats_type == "eigs":
try:
if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
eigs, _ = torch.linalg.eigh(stats)
else:
eigs, _ = torch.symeig(stats)
stats = eigs.abs().sqrt()
except: # noqa
print(
"Error getting eigenvalues, trying another method."
)
if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'):
print("Error getting eigenvalues, trying another method.")
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"):
eigs, _ = torch.linalg.eig(stats)
eigs = eigs.abs()
else:
@ -579,10 +577,15 @@ def attach_diagnostics(
)
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
class_name=get_class_name(_module))
if isinstance(o, Tensor) and o.dtype in (
torch.float32,
torch.float16,
torch.float64,
):
_model_diagnostic[f"{_name}.output[{i}]"].accumulate(
o, class_name=get_class_name(_module)
)
def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
if isinstance(_output, tuple) and len(_output) == 1:
_output = _output[0]
@ -596,9 +599,15 @@ def attach_diagnostics(
)
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
class_name=get_class_name(_module))
if isinstance(o, Tensor) and o.dtype in (
torch.float32,
torch.float16,
torch.float64,
):
_model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
o, class_name=get_class_name(_module)
)
module.register_forward_hook(forward_hook)
module.register_backward_hook(backward_hook)