Merge branch 'k2-fsa:master' into k2ssl

Yifan Yang 2024-03-31 16:53:25 +08:00 committed by GitHub
commit dfbacbe4dc
11 changed files with 578 additions and 14 deletions

View File

@@ -90,7 +90,7 @@ jobs:
           path: ./*.wav
       - name: Release exported onnx models
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
+        if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true

View File

@@ -74,6 +74,10 @@ to install dependencies of `icefall`_:
   pip install k2==1.24.4.dev20231220+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu.html
+  # For users from China
+  # If you cannot access huggingface, please use:
+  # pip install k2==1.24.4.dev20231220+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu-cn.html
+
   # Install the latest version of lhotse
   pip install git+https://github.com/lhotse-speech/lhotse

View File

@@ -206,6 +206,9 @@ We will install `k2`_ from pre-compiled wheels by following
 .. code-block:: bash
   (test-icefall) kuangfangjun:~$ pip install k2==1.24.3.dev20230725+cuda11.6.torch1.13.0 -f https://k2-fsa.github.io/k2/cuda.html
+  # For users from China
+  # If you cannot access huggingface, please use:
+  # pip install k2==1.24.3.dev20230725+cuda11.6.torch1.13.0 -f https://k2-fsa.github.io/k2/cuda-cn.html
   Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
   Looking in links: https://k2-fsa.github.io/k2/cuda.html

View File

@@ -47,7 +47,9 @@ fi
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   log "Stage 1: Decode the model."
-  for t in small, large; do
+  export CUDA_VISIBLE_DEVICES="0"
+  for t in small large; do
     python ./zipformer/decode.py \
       --epoch 12 \
      --avg 2 \
@@ -123,6 +125,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
     --exp-dir zipformer/exp_finetune \
     --bpe-model data/lang_bpe_500/bpe.model \
     --use-fp16 1 \
+    --use-mux 1 \
     --decoder-dim 320 \
     --joiner-dim 320 \
     --num-encoder-layers 1,1,1,1,1,1 \
@@ -139,7 +142,8 @@ fi
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   log "Stage 1: Decode the finetuned model."
-  for t in small, large; do
+  export CUDA_VISIBLE_DEVICES="0"
+  for t in small large; do
     python ./zipformer/decode.py \
       --epoch 10 \
       --avg 2 \
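
Note on the loop fix above: in a POSIX shell, the comma in "for t in small, large" is not a list separator; it becomes part of the first word, so the loop iterated over "small," and "large". A quick check in any shell:

    for t in small, large; do echo "[$t]"; done   # prints [small,] then [large]
    for t in small large; do echo "[$t]"; done    # prints [small] then [large]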

View File

@@ -788,7 +788,7 @@ class Zipformer2EncoderLayer(nn.Module):
             selected_attn_weights = attn_weights[0:1]
         if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
-        elif not self.training and random.random() < float(self.const_attention_rate):
+        elif self.training and random.random() < float(self.const_attention_rate):
             # Make attention weights constant. The intention is to
             # encourage these modules to do something similar to an
             # averaging-over-time operation.
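
Note: the flipped condition above makes the constant-attention regularization fire during training, as intended; previously it only fired in eval mode, where randomized regularization has no effect on learning. A minimal sketch of the idea, with illustrative names rather than the module's exact code:

    import torch

    # attn_weights: (num_heads, batch, seq_len, seq_len); each row sums to 1
    attn_weights = torch.softmax(torch.randn(2, 3, 6, 6), dim=-1)

    # Replacing the data-dependent weights with a uniform distribution over
    # time turns the attention module into an averaging-over-time operation,
    # which is the stated intent of const_attention_rate:
    const_weights = torch.full_like(attn_weights, 1.0 / attn_weights.size(-1))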

View File

@@ -27,11 +27,13 @@ popd
 2. Export the model to ONNX
-./zipformer/export-onnx.py \
+./zipformer_adapter/export-onnx.py \
   --tokens $repo/data/lang_bpe_500/tokens.txt \
   --use-averaged-model 0 \
   --epoch 99 \
   --avg 1 \
+  --use-adapters 1 \
+  --adapter-dim 32 \
   --exp-dir $repo/exp \
   --num-encoder-layers "2,2,3,4,3,2" \
   --downsampling-factor "1,2,4,8,4,2" \
@@ -131,7 +133,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="zipformer/exp",
+        default="zipformer_adapter/exp",
         help="""It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
         """,

View File

@@ -0,0 +1,520 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Wei Kang,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
Note: This is an example for the LibriSpeech dataset; if you are using a
different dataset, you should change the argument values accordingly.
(1) Export to torchscript model using torch.jit.script()
- For non-streaming model:
./zipformer_adapter/export.py \
--exp-dir ./zipformer_adapter/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--use-adapters 1 \
--adapter-dim 16 \
--avg 9 \
--jit 1
It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("jit_script.pt")`.
Check ./jit_pretrained.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
- For streaming model:
./zipformer_adapter/export.py \
--exp-dir ./zipformer_adapter/exp \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens data/lang_bpe_500/tokens.txt \
--use-adapters 1 \
--adapter-dim 16 \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`.
You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`.
Check ./jit_pretrained_streaming.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
- For non-streaming model:
./zipformer_adapter/export.py \
--exp-dir ./zipformer_adapter/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--use-adapters 1 \
--adapter-dim 16 \
--avg 9
- For streaming model:
./zipformer_adapter/export.py \
--exp-dir ./zipformer_adapter/exp \
--causal 1 \
--tokens data/lang_bpe_500/tokens.txt \
--use-adapters 1 \
--adapter-dim 16 \
--epoch 30 \
--avg 9
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
- For non-streaming model:
To use the generated file with `zipformer_adapter/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./zipformer_adapter/decode_gigaspeech.py \
--exp-dir ./zipformer_adapter/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--use-adapters 1 \
--adapter-dim 16 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
- For streaming model:
To use the generated file with `zipformer_adapter/decode.py` and `zipformer_adapter/streaming_decode.py`, you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
# simulated streaming decoding
./zipformer_adapter/decode_gigaspeech.py \
--exp-dir ./zipformer_adapter/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
# chunk-wise streaming decoding
./zipformer_adapter/streaming_decode.py \
--exp-dir ./zipformer_adapter/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
Check ./pretrained.py for its usage.
"""
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from torch import Tensor, nn
from train import add_finetune_arguments, add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, num_tokens, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer_adapter/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named jit_script.pt.
Check ./jit_pretrained.py for how to use it.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
add_finetune_arguments(parser)
return parser
class EncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
def forward(
self, features: Tensor, feature_lengths: Tensor
) -> Tuple[Tensor, Tensor]:
"""
Args:
features: (N, T, C)
feature_lengths: (N,)
"""
x, x_lens = self.encoder_embed(features, feature_lengths)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
return encoder_out, encoder_out_lens
class StreamingEncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
assert len(encoder.chunk_size) == 1, encoder.chunk_size
assert len(encoder.left_context_frames) == 1, encoder.left_context_frames
self.chunk_size = encoder.chunk_size[0]
self.left_context_len = encoder.left_context_frames[0]
# The encoder_embed subsamples features: T -> (T - 7) // 2
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
self.pad_length = 7 + 2 * 3
self.encoder = encoder
self.encoder_embed = encoder_embed
def forward(
self, features: Tensor, feature_lengths: Tensor, states: List[Tensor]
) -> Tuple[Tensor, Tensor, List[Tensor]]:
"""Streaming forward for encoder_embed and encoder.
Args:
features: (N, T, C)
feature_lengths: (N,)
states: a list of Tensors
Returns encoder outputs, output lengths, and updated states.
"""
chunk_size = self.chunk_size
left_context_len = self.left_context_len
cached_embed_left_pad = states[-2]
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lengths,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = self.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, encoder_out_lens, new_states
@torch.jit.export
def get_init_states(
self,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
"""
states = self.encoder.get_init_states(batch_size, device)
embed_states = self.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
states.append(processed_lens)
return states
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
logging.info(f"device: {device}")
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.eval()
if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
# Wrap encoder and encoder_embed as a module
if params.causal:
model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed)
chunk_size = model.encoder.chunk_size
left_context_len = model.encoder.left_context_len
filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
else:
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
filename = "jit_script.pt"
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
model.save(str(params.exp_dir / filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
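
Note: per the docstring above, the --jit 1 export can be loaded back with a couple of lines; a minimal sketch, assuming the default exp-dir from the examples:

    import torch

    model = torch.jit.load("zipformer_adapter/exp/jit_script.pt")
    model.eval()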

View File

@@ -1004,6 +1004,9 @@ class Zipformer2EncoderLayer(nn.Module):
         )
         src = src + self_attn
+        if self.use_adapters and self.post_sa_adapter is not None:
+            src = self.post_sa_adapter(src)
+
         src_conv, cached_conv1 = self.conv_module1.streaming_forward(
             src,
             cache=cached_conv1,
@@ -1016,6 +1019,9 @@ class Zipformer2EncoderLayer(nn.Module):
         # bypass in the middle of the layer.
         src = self.bypass_mid(src_orig, src)
+        if self.use_adapters and self.mid_adapter is not None:
+            src = self.mid_adapter(src)
+
         self_attn, cached_val2 = self.self_attn2.streaming_forward(
             src,
             attn_weights=attn_weights,
@@ -1031,12 +1037,18 @@ class Zipformer2EncoderLayer(nn.Module):
         )
         src = src + src_conv
+        if self.use_adapters and self.post_conv_adapter is not None:
+            src = self.post_conv_adapter(src)
+
         src = src + self.feed_forward3(src)
         src = self.norm(src)
         src = self.bypass(src_orig, src)
+        if self.use_adapters and self.adapter is not None:
+            src = self.adapter(src)
+
         return (
             src,
             cached_key,
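
Note: the post_sa_adapter / mid_adapter / post_conv_adapter / adapter calls added above follow the usual residual bottleneck adapter pattern. A sketch of what such a module typically looks like (an assumption for illustration, not the exact icefall class):

    import torch.nn as nn

    class BottleneckAdapter(nn.Module):  # hypothetical name
        def __init__(self, embed_dim: int, adapter_dim: int = 16):
            super().__init__()
            self.down = nn.Linear(embed_dim, adapter_dim)  # project down
            self.activation = nn.ReLU()
            self.up = nn.Linear(adapter_dim, embed_dim)    # project back up

        def forward(self, src):
            # The residual connection preserves the pretrained computation
            # path; only the small bottleneck is trained during fine-tuning.
            return src + self.up(self.activation(self.down(src)))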

View File

@@ -48,7 +48,8 @@ fi
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   log "Stage 1: Decode the model."
-  for t in small, large; do
+  export CUDA_VISIBLE_DEVICES="0"
+  for t in small large; do
     python ./zipformer/decode.py \
       --epoch 18 \
       --avg 2 \
@@ -126,6 +127,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
     --lang-dir ./data/lang_partial_tone \
     --pinyin-type partial_with_tone \
     --use-fp16 1 \
+    --use-mux 1 \
     --decoder-dim 320 \
     --joiner-dim 320 \
     --num-encoder-layers 1,1,1,1,1,1 \
@@ -142,7 +144,8 @@ fi
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   log "Stage 1: Decode the finetuned model."
-  for t in small, large; do
+  export CUDA_VISIBLE_DEVICES="0"
+  for t in small large; do
     python ./zipformer/decode.py \
       --epoch 10 \
       --avg 2 \

View File

@@ -1,6 +1,7 @@
-# Copyright    2022       Xiaomi Corp.  (authors: Daniel Povey
+# Copyright    2022-2024  Xiaomi Corp.  (authors: Daniel Povey
 #                                                 Zengwei Yao
-#                                                 Mingshuang Luo)
+#                                                 Mingshuang Luo,
+#                                                 Zengrui Jin,)
 #
 # See ../LICENSE for clarification regarding multiple authors
 #
@@ -16,9 +17,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import random
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
 import torch
 from torch import Tensor, nn
@@ -653,7 +655,13 @@ def attach_diagnostics(
             _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
             _model_diagnostic[f"{_name}.param_grad"].accumulate(grad)
-        parameter.register_hook(param_backward_hook)
+        try:
+            parameter.register_hook(param_backward_hook)
+        except:
+            logging.warning(
+                f"Warning: could not register backward hook for parameter {name}, "
+                f"it might not be differentiable."
+            )
     return ans
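
Note: the try/except added above is needed because Tensor.register_hook raises a RuntimeError on tensors that do not require grad, e.g. parameters frozen during adapter fine-tuning. A quick demonstration:

    import torch

    p = torch.nn.Parameter(torch.zeros(3), requires_grad=False)
    try:
        p.register_hook(lambda grad: grad)
    except RuntimeError as e:
        print(f"could not register hook: {e}")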

View File

@@ -1,4 +1,6 @@
-# Copyright    2021-2022  Xiaomi Corporation  (authors: Zengwei Yao, Daniel Povey)
+# Copyright    2021-2024  Xiaomi Corporation  (authors: Zengwei Yao,
+#                                                       Daniel Povey,
+#                                                       Zengrui Jin,)
 #
 # See ../../LICENSE for clarification regarding multiple authors
 #
@@ -77,7 +79,13 @@ def register_inf_check_hooks(model: nn.Module) -> None:
             if not torch.isfinite(grad.to(torch.float32).sum()):
                 logging.warning(f"The sum of {_name}.param_grad is not finite")
-        parameter.register_hook(param_backward_hook)
+        try:
+            parameter.register_hook(param_backward_hook)
+        except:
+            logging.warning(
+                f"Warning: could not register backward hook for parameter {name}, "
+                f"it might not be differentiable."
+            )
 def _test_inf_check_hooks():