Minor refinements for some stale but recently merged PRs (#1354)

* incorporate https://github.com/k2-fsa/icefall/pull/1269

* incorporate https://github.com/k2-fsa/icefall/pull/1301

* black formatted

* incorporate https://github.com/k2-fsa/icefall/pull/1162

* black formatted
Author: zr_jin
Date: 2023-10-31 10:28:20 +08:00 (committed by GitHub)
Parent: c970df512b
Commit: 23913f6afd
14 changed files with 57 additions and 68 deletions


@@ -1128,7 +1128,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
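
Note: the `2**22` → `512` change tracks an icefall diagnostics refactor in which the first argument of `TensorDiagnosticOptions` stopped being a byte budget; the retained "4 megabytes" comment is a leftover from the old signature. A minimal sketch of how these call sites use it, assuming the post-refactor single-argument constructor (the stand-in model and input are illustrative):

    import torch
    from icefall import diagnostics

    model = torch.nn.Linear(4, 4)  # stand-in module for illustration
    opts = diagnostics.TensorDiagnosticOptions(512)
    diagnostic = diagnostics.attach_diagnostics(model, opts)

    # Run at least one forward/backward pass so the attached hooks
    # can accumulate per-module tensor statistics.
    model(torch.randn(8, 4)).sum().backward()
    diagnostic.print_diagnostics()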


@@ -233,7 +233,7 @@ class GigaSpeechAsrDataModule:
             logging.info("About to get Musan cuts")
             cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")


@@ -1164,7 +1164,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)


@@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer):

        yield tuples  # <-- calling code will do the actual optimization here!

-       for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
+       for (stacked_params, _state, _names), batch in zip(tuples, batches):
            for i, p in enumerate(batch):  # batch is list of Parameter
                p.copy_(stacked_params[i])

@@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer):
        size_update_period=4,
        clipping_update_period=100,
    ):
-
        defaults = dict(
            lr=lr,
            clipping_scale=clipping_scale,

@@ -327,9 +326,7 @@ class ScaledAdam(BatchedOptimizer):
        batch = True
-
        for group, group_params_names in zip(self.param_groups, self.parameters_names):
            with self.batched_params(group["params"], group_params_names) as batches:
-
                # batches is list of pairs (stacked_param, state). stacked_param is like
                # a regular parameter, and will have a .grad, but the 1st dim corresponds to
                # a stacking dim, it is not a real dim.

@@ -429,7 +426,7 @@ class ScaledAdam(BatchedOptimizer):
        clipping_update_period = group["clipping_update_period"]
        tot_sumsq = torch.tensor(0.0, device=first_p.device)
-       for (p, state, param_names) in tuples:
+       for p, state, param_names in tuples:
            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError(

@@ -514,7 +511,7 @@ class ScaledAdam(BatchedOptimizer):
        from tuples, we still pass it to save some time.
        """
        all_sumsq_orig = {}
-       for (p, state, batch_param_names) in tuples:
+       for p, state, batch_param_names in tuples:
            # p is a stacked batch parameters.
            batch_grad = p.grad
            if p.numel() == p.shape[0]:  # a batch of scalars

@@ -530,7 +527,6 @@ class ScaledAdam(BatchedOptimizer):
            for name, sumsq_orig, rms, grad in zip(
                batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
            ):
-
                proportion_orig = sumsq_orig / tot_sumsq
                all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)

@@ -1106,7 +1102,7 @@ def _test_scaled_adam(hidden_dim: int):

        # if epoch == 130:
        #     opts = diagnostics.TensorDiagnosticOptions(
-       #         2 ** 22
+       #         512
        #     )  # allow 4 megabytes per sub-module
        #     diagnostic = diagnostics.attach_diagnostics(m, opts)
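
Note: the loops above unpack `tuples`, whose elements pair a stacked parameter tensor with its optimizer state and parameter names; `batched_params` is a context manager that stacks same-shaped parameters, lets the caller update the stacked copy, and writes the slices back on exit. A simplified, self-contained sketch of that pattern (not ScaledAdam's actual implementation, which also groups by dtype and carries state and names):

    import contextlib
    import torch

    @contextlib.contextmanager
    def batched(params):
        batch = list(params)  # assumes identical shapes for simplicity
        stacked = torch.stack([p.detach() for p in batch])
        yield stacked  # caller mutates the stacked copy here
        with torch.no_grad():
            for i, p in enumerate(batch):
                p.copy_(stacked[i])  # write the update back to each Parameter

    ps = [torch.nn.Parameter(torch.zeros(3)) for _ in range(2)]
    with batched(ps) as s:
        s += 1.0  # stands in for an optimizer update on the stacked tensor
    print(ps[0])  # now tensor([1., 1., 1.])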


@@ -1194,7 +1194,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)


@@ -1565,7 +1565,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         args.max_duration = 100
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)


@@ -31,7 +31,7 @@ from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
-    SingleCutSampler,
+    SimpleCutSampler,
     SpecAugment,
 )
 from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples

@@ -225,7 +225,7 @@ class LibriSpeechAsrDataModule:
             logging.info("About to get Musan cuts")
             cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")

@@ -307,8 +307,8 @@ class LibriSpeechAsrDataModule:
                 drop_last=self.args.drop_last,
             )
         else:
-            logging.info("Using SingleCutSampler.")
-            train_sampler = SingleCutSampler(
+            logging.info("Using SimpleCutSampler.")
+            train_sampler = SimpleCutSampler(
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
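
Note: `SingleCutSampler` was renamed `SimpleCutSampler` in lhotse (the old name lingered for a while as a deprecated alias). A minimal construction sketch mirroring the datamodule code (path and values are illustrative):

    from lhotse import load_manifest
    from lhotse.dataset import SimpleCutSampler

    cuts_train = load_manifest("data/fbank/train_cuts.jsonl.gz")  # illustrative path
    train_sampler = SimpleCutSampler(
        cuts_train,
        max_duration=200.0,  # seconds of audio per mini-batch
        shuffle=True,
    )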


@@ -22,10 +22,11 @@
 import argparse
 import logging
 import math
+import pprint
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
-import pprint
+
 import k2
 import sentencepiece as spm
 import torch


@@ -253,7 +253,6 @@ class CausalSqueezeExcite1d(nn.Module):
         return y

     def forward(self, x: Tensor) -> Tensor:
         assert len(x.shape) == 3, "Input is not a 3D tensor!"
-
         y = self.exponential_moving_avg(x)
         y = y.permute(0, 2, 1)  # make channel last for squeeze op


@@ -76,17 +76,17 @@ import argparse
 import logging
 from pathlib import Path

-import sentencepiece as spm
+import k2
 import torch
+from train import add_model_arguments, get_params, get_transducer_model

 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.lexicon import UniqLexicon
-from icefall.utils import str2bool
-from train import add_model_arguments, get_params, get_transducer_model
+from icefall.utils import num_tokens, str2bool

 def get_parser():

@@ -143,13 +143,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--lang-dir",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500",
-        help="""The lang dir
-        It contains language related input files such as
-        "lexicon.txt"
-        """,
+        default="data/lang_bpe_500/tokens.txt",
+        help="Path to the tokens.txt",
     )

     parser.add_argument(

@@ -189,17 +186,9 @@ def main():
     logging.info(f"device: {device}")

-    if "lang_bpe" in str(params.lang_dir):
-        sp = spm.SentencePieceProcessor()
-        sp.load(params.lang_dir + "/bpe.model")
-        # <blk> is defined in local/train_bpe_model.py
-        params.blank_id = sp.piece_to_id("<blk>")
-        params.vocab_size = sp.get_piece_size()
-    else:
-        assert "lang_phone" in str(params.lang_dir)
-        phone_lexicon = UniqLexicon(params.lang_dir)
-        params.blank_id = 0
-        params.vocab_size = max(phone_lexicon.tokens) + 1
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(params)
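
Note: this replaces the BPE-vs-phone branching with a single path: `tokens.txt` (read as a k2 symbol table) now supplies both the blank id and the vocabulary size, so the export script no longer needs sentencepiece or a lexicon. A self-contained sketch of that computation (the tiny `tokens.txt` written here is illustrative; `num_tokens` is icefall's helper, assumed here to count token ids while excluding disambiguation symbols and the blank itself, hence the +1):

    import k2
    from icefall.utils import num_tokens

    with open("tokens.txt", "w", encoding="utf-8") as f:
        f.write("<blk> 0\na 1\nb 2\n")

    token_table = k2.SymbolTable.from_file("tokens.txt")
    blank_id = token_table["<blk>"]           # 0 by convention
    vocab_size = num_tokens(token_table) + 1  # +1 re-adds <blk>
    print(blank_id, vocab_size)               # expected: 0 3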


@@ -89,7 +89,6 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:

 def add_model_arguments(parser: argparse.ArgumentParser):
-
     parser.add_argument(
         "--encoder-dim",
         type=int,

@@ -405,7 +404,6 @@ def get_params() -> AttributeDict:

 def get_encoder_model(params: AttributeDict) -> nn.Module:
-
     encoder = Conv1dNet(
         output_dim=params.encoder_dim,
         input_dim=params.feature_dim,

@@ -1043,7 +1041,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)


@@ -23,6 +23,7 @@ import argparse
 import logging
 from pathlib import Path

+import k2
 import torch
 from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_ctc_model, get_params

@@ -33,8 +34,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool

 def get_parser():

@@ -90,11 +90,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--lang-dir",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500",
-        help="""It contains language related input files such as "lexicon.txt"
-        """,
+        default="data/lang_bpe_500/tokens.txt",
+        help="Path to the tokens.txt",
     )

     parser.add_argument(

@@ -113,17 +112,15 @@ def get_parser():
 def main():
     args = get_parser().parse_args()
     args.exp_dir = Path(args.exp_dir)
-    args.lang_dir = Path(args.lang_dir)

     params = get_params()
     params.update(vars(args))

-    logging.info(params)
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

-    lexicon = Lexicon(params.lang_dir)
-    max_token_id = max(lexicon.tokens)
-    num_classes = max_token_id + 1  # +1 for the blank
-    params.vocab_size = num_classes
+    logging.info(params)

     device = torch.device("cpu")
     if torch.cuda.is_available():


@@ -947,7 +947,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)


@@ -244,16 +244,14 @@ class TensorDiagnostic(object):

                if stats_type == "eigs":
                    try:
-                       if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
+                       if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
                            eigs, _ = torch.linalg.eigh(stats)
                        else:
                            eigs, _ = torch.symeig(stats)
                        stats = eigs.abs().sqrt()
                    except:  # noqa
-                       print(
-                           "Error getting eigenvalues, trying another method."
-                       )
-                       if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'):
+                       print("Error getting eigenvalues, trying another method.")
+                       if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"):
                            eigs, _ = torch.linalg.eig(stats)
                            eigs = eigs.abs()
                        else:
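
Note: the quote changes are black's doing; the substantive pattern is the fallback chain: prefer `torch.linalg.eigh` where available, fall back to the deprecated `torch.symeig` on old PyTorch, and on any failure retry with the non-symmetric `torch.linalg.eig`. A minimal sketch of the version-robust symmetric case:

    import torch

    def abs_sqrt_eigs(stats: torch.Tensor) -> torch.Tensor:
        """Return |eigenvalues|**0.5 of a symmetric matrix across PyTorch versions."""
        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
            eigs, _ = torch.linalg.eigh(stats)  # modern API
        else:
            eigs, _ = torch.symeig(stats)  # legacy API on old PyTorch
        return eigs.abs().sqrt()

    x = torch.randn(4, 4)
    print(abs_sqrt_eigs(x @ x.t()))  # eigenvalues of a symmetric PSD matrix
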
@@ -579,10 +577,15 @@ def attach_diagnostics(
                )
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
-                   if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
-                       _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
-                           class_name=get_class_name(_module))
+                   if isinstance(o, Tensor) and o.dtype in (
+                       torch.float32,
+                       torch.float16,
+                       torch.float64,
+                   ):
+                       _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
+                           o, class_name=get_class_name(_module)
+                       )

        def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
            if isinstance(_output, tuple) and len(_output) == 1:
                _output = _output[0]

@@ -596,9 +599,15 @@ def attach_diagnostics(
                )
            elif isinstance(_output, tuple):
                for i, o in enumerate(_output):
-                   if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
-                       _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
-                           class_name=get_class_name(_module))
+                   if isinstance(o, Tensor) and o.dtype in (
+                       torch.float32,
+                       torch.float16,
+                       torch.float64,
+                   ):
+                       _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
+                           o, class_name=get_class_name(_module)
+                       )

        module.register_forward_hook(forward_hook)
        module.register_backward_hook(backward_hook)
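
Note: the reformatted hooks above implement one pattern: when a module's output is a tuple, accumulate statistics only for elements that are floating-point tensors, skipping ints, bools, and non-tensors. A self-contained sketch of that filtering hook (the `record` callback stands in for the diagnostics accumulator):

    import torch
    from torch import Tensor, nn

    FLOAT_DTYPES = (torch.float32, torch.float16, torch.float64)

    def make_forward_hook(name, record):
        def forward_hook(module, _input, output):
            outputs = output if isinstance(output, tuple) else (output,)
            for i, o in enumerate(outputs):
                if isinstance(o, Tensor) and o.dtype in FLOAT_DTYPES:
                    record(f"{name}.output[{i}]", o)  # floats only
        return forward_hook

    stats = {}
    m = nn.Linear(2, 2)
    m.register_forward_hook(
        make_forward_hook("lin", lambda k, t: stats.setdefault(k, t.norm()))
    )
    m(torch.randn(3, 2))
    print(stats)  # {'lin.output[0]': tensor(...)}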