Mirror of https://github.com/k2-fsa/icefall.git
Commit 00fefa2b22: Merge d941c516c00d260875945064edc80eb8a00add81 into 3199058194a48d45aeee740f2aa9bdbef0bec29d
@@ -25,6 +25,7 @@
 
 stage=0
 stop_stage=4
+. shared/parse_options.sh || exit 1
 
 # Set the GPUs available.
 # This script requires at least one GPU.
@@ -32,7 +33,7 @@ stop_stage=4
 # even if you only have ONE GPU. It is needed by CodebookIndexExtractor to determine the number of jobs to extract codebook indexes in parallel.
 
 # Suppose only one GPU exists:
-# export CUDA_VISIBLE_DEVICES="0"
+export CUDA_VISIBLE_DEVICES="0"
 #
 # Suppose GPU 2,3,4,5 are available.
 # export CUDA_VISIBLE_DEVICES="0,1,2,3"
@@ -154,9 +155,8 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   mkdir -p codebook_dir
   codebook_download_dir=$exp_dir/download_codebook
   if [ -d $codebook_download_dir ]; then
-    log "$codebook_download_dir exists, you should remove it first."
-    exit 1
-  fi
+    log "$codebook_download_dir exists, skip downloading it."
+  else
   log "Downloading extracted codebook indexes to $codebook_download_dir"
   # Make sure you have git-lfs installed (https://git-lfs.github.com)
   # The codebook indexes are generated using lhotse 1.11.0, to avoid
@@ -165,16 +165,25 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   if [ "$lhotse_version" == "False" ]; then
     log "Expecting lhotse >= 1.11.0. This may lead to potential ID mismatch."
   fi
-  git lfs install
-  git clone https://huggingface.co/marcoyang/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir
+  GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/marcoyang/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir
+  pushd $codebook_download_dir
+  if [ "$full_libri" == "False" ]; then
+    log "Only download the train-clean-100 subset"
+    git lfs pull --include "*clean-100*"
+  else
+    log "Download the full training set"
+    git lfs fetch --all
+  fi
+  popd
+  fi
 
   vq_fbank=data/vq_fbank_layer${embedding_layer}_cb${num_codebooks}/
   mkdir -p $vq_fbank
   mv $codebook_download_dir/*.jsonl.gz $vq_fbank
   mkdir -p $codebook_dir/splits4
   mv $codebook_download_dir/*.h5 $codebook_dir/splits4/
-  log "Remove $codebook_download_dir"
-  rm -rf $codebook_download_dir
+  # log "Remove $codebook_download_dir"
+  # rm -rf $codebook_download_dir
 fi
 
 ./pruned_transducer_stateless6/extract_codebook_index.py \
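Note on the `lhotse_version` guard above: the script now only warns (rather than aborting) when the installed lhotse predates 1.11.0. A minimal sketch of how that "True"/"False" string might be computed on the Python side; the helper below is hypothetical, not the recipe's actual code:

```python
# Hypothetical helper: emit the "True"/"False" string consumed by the
# shell test [ "$lhotse_version" == "False" ] above.
import lhotse
from packaging import version

def lhotse_at_least(required: str = "1.11.0") -> str:
    return str(version.parse(lhotse.__version__) >= version.parse(required))
```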
@@ -27,8 +27,6 @@ from torch import nn
 from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost
 from icefall.decode import Nbest, one_best_decoding
 from icefall.lm_wrapper import LmScorer
-from icefall.rnn_lm.model import RnnLmModel
-from icefall.transformer_lm.model import TransformerLM
 from icefall.utils import (
     DecodingResults,
     add_eos,
@@ -16,15 +16,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
+from multi_quantization.prediction import JointCodebookLoss
+from scaling import ScaledLinear
 
 from icefall.utils import add_sos, make_pad_mask
-from scaling import ScaledLinear
 
 
 class AsrModel(nn.Module):
@@ -39,12 +40,15 @@ class AsrModel(nn.Module):
         vocab_size: int = 500,
         use_transducer: bool = True,
         use_ctc: bool = False,
+        num_codebooks: int = 8,
+        cb_input_dim: int = 384,
     ):
         """A joint CTC & Transducer ASR model.
 
         - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
         - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
         - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
+        - Potentially with MVQ knowledge distillation (https://arxiv.org/abs/2211.00508)
 
         Args:
           encoder_embed:
@@ -70,6 +74,10 @@ class AsrModel(nn.Module):
           use_transducer:
             Whether to use a transducer head. Default: True.
           use_ctc:
             Whether to use a CTC head. Default: False.
+          num_codebooks:
+            Greater than 0 if we want to do MVQ knowledge distillation.
+          cb_input_dim:
+            The input dimension to the codebook loss module.
         """
         super().__init__()
 
@@ -111,6 +119,12 @@ class AsrModel(nn.Module):
             nn.LogSoftmax(dim=-1),
         )
 
+        if num_codebooks > 0:
+            self.codebook_loss_net = JointCodebookLoss(
+                predictor_channels=cb_input_dim,
+                num_codebooks=num_codebooks,
+            )
+
     def forward_encoder(
         self, x: torch.Tensor, x_lens: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
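For readers unfamiliar with `JointCodebookLoss`: as the call sites later in this diff show, it takes middle-layer embeddings of shape (N, T, C) together with integer codebook indexes and returns the distillation loss. A minimal sketch; the batch shapes and the codebook vocabulary size of 256 are illustrative assumptions:

```python
import torch
from multi_quantization.prediction import JointCodebookLoss

loss_net = JointCodebookLoss(predictor_channels=384, num_codebooks=16)
embeddings = torch.randn(2, 100, 384)          # (N, T, C), e.g. middle_out[0]
indexes = torch.randint(0, 256, (2, 100, 16))  # (N, T, num_codebooks)
codebook_loss = loss_net(embeddings, indexes)  # the distillation loss used above
```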
@@ -127,6 +141,8 @@ class AsrModel(nn.Module):
             Encoder output, of shape (N, T, C).
           encoder_out_lens:
             Encoder output lengths, of shape (N,).
+          saved_embeddings:
+            The embeddings from the middle layers
         """
         # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
         x, x_lens = self.encoder_embed(x, x_lens)
@@ -135,12 +151,14 @@ class AsrModel(nn.Module):
         src_key_padding_mask = make_pad_mask(x_lens)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
-        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
+        encoder_out, encoder_out_lens, middle_out = self.encoder(
+            x, x_lens, src_key_padding_mask
+        )
 
         encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
         assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
 
-        return encoder_out, encoder_out_lens
+        return encoder_out, encoder_out_lens, middle_out
 
     def forward_ctc(
         self,
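The encoder now returns a third element, `middle_out`; the encoder-side change itself lives in the file whose diff is suppressed at the bottom of this page. As a generic illustration of surfacing one layer's output (a toy model with hypothetical names, not icefall's actual mechanism), a forward hook achieves the same effect:

```python
import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(6))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x

encoder = ToyEncoder()
middle_out = []
# Capture the output of layer 4 (cf. the --distillation-layer default below).
encoder.layers[4].register_forward_hook(
    lambda module, inputs, output: middle_out.append(output)
)
_ = encoder(torch.randn(2, 10, 8))
assert len(middle_out) == 1 and middle_out[0].shape == (2, 10, 8)
```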
@@ -180,6 +198,7 @@ class AsrModel(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        codebook_indexes: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute Transducer loss.
         Args:
@@ -286,6 +305,7 @@ class AsrModel(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        codebook_indexes: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -306,9 +326,12 @@ class AsrModel(nn.Module):
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
+          codebook_indexes:
+            The codebook indexes to be predicted. Only used when doing knowledge
+            distillation with MVQ
         Returns:
-          Return the transducer losses and CTC loss,
-          in form of (simple_loss, pruned_loss, ctc_loss)
+          Return the transducer losses and CTC loss, and potentially codebook loss
+          in form of (simple_loss, pruned_loss, ctc_loss, codebook_loss)
 
         Note:
            Regarding am_scale & lm_scale, it will make the loss-function one of
@@ -323,7 +346,7 @@ class AsrModel(nn.Module):
         assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
 
         # Compute encoder outputs
-        encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
+        encoder_out, encoder_out_lens, middle_out = self.forward_encoder(x, x_lens)
 
         row_splits = y.shape.row_splits(1)
         y_lens = row_splits[1:] - row_splits[:-1]
@@ -355,4 +378,85 @@ class AsrModel(nn.Module):
         else:
             ctc_loss = torch.empty(0)
 
-        return simple_loss, pruned_loss, ctc_loss
+        if self.training and hasattr(self, "codebook_loss_net"):
+            assert codebook_indexes is not None
+            codebook_loss = self.forward_codebook(
+                middle_out=middle_out,
+                codebook_indexes=codebook_indexes,
+            )
+        else:
+            codebook_loss = torch.empty(0)
+
+        return simple_loss, pruned_loss, ctc_loss, codebook_loss
+
+    def forward_codebook(
+        self,
+        middle_out: List[torch.Tensor],
+        codebook_indexes: torch.Tensor,
+    ) -> torch.Tensor:
+        """Calculate the codebook loss for the model (knowledge distillation)
+
+        Args:
+          middle_out (List[torch.Tensor]):
+            The embeddings extracted from the middle layer of the zipformer encoder
+          codebook_indexes (torch.Tensor):
+            The encoded codebook indexes for knowledge distillation
+
+        Returns:
+          The codebook loss value
+        """
+        middle_layer_output = middle_out[
+            0
+        ]  # currently only support using the output of one layer, (N, T, C)
+        len_CI = codebook_indexes.size(1)
+        len_mid_layer = middle_layer_output.size(1)
+        ratio = round(len_CI / len_mid_layer)
+
+        if ratio == 1:  # having the same frame rate
+            assert len_CI > len_mid_layer, (len_CI, len_mid_layer)
+            codebook_indexes = codebook_indexes[:, :len_mid_layer, :]
+            assert codebook_indexes.size(1) == middle_layer_output.size(1)
+            codebook_loss = self.codebook_loss_net(
+                middle_layer_output, codebook_indexes
+            )
+        elif ratio == 2:
+            codebook_indexes = self.concat_successive_codebook_indexes(
+                middle_layer_output, codebook_indexes
+            )
+            codebook_loss = self.codebook_loss_net(
+                middle_layer_output, codebook_indexes
+            )
+
+        return codebook_loss
+
+    @staticmethod
+    def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes):
+        # The output rate of hubert is 50 frames per second,
+        # while that of the current encoder is 25.
+        # The following code handles two issues:
+        # 1.
+        # Roughly speaking, to generate another frame of output,
+        # hubert needs an extra two frames,
+        # while the current encoder needs an extra four frames.
+        # If only an extra three frames are provided, hubert will generate
+        # another frame while the current encoder does nothing.
+        # 2.
+        # The codebook loss is a frame-wise loss. To let the 25-frame student
+        # output learn from the 50-frame teacher output, two successive frames
+        # of the teacher model's output are concatenated together.
+        t_expected = middle_layer_output.shape[1]
+        N, T, C = codebook_indexes.shape
+        assert T >= t_expected, (T, t_expected)
+        # Handle issue 1.
+        if T >= t_expected * 2:
+            codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
+        if (
+            T / t_expected < 1.1
+        ):  # To be changed: a dirty hack to jump out of this function
+            codebook_indexes = codebook_indexes[:, :t_expected, :]
+            assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
+            return codebook_indexes
+        # Handle issue 2.
+        codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
+        assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
+        return codebook_indexes
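A quick worked example of the `ratio == 2` path above: pairing successive teacher frames is just a reshape that halves the time axis and doubles the codebook axis (toy sizes below).

```python
import torch

N, T, C = 1, 6, 4  # teacher indexes at 50 fps; student expects T // 2 frames
codebook_indexes = torch.arange(N * T * C).reshape(N, T, C)
t_expected = T // 2
paired = codebook_indexes[:, : t_expected * 2, :].reshape(N, t_expected, C * 2)
assert paired.shape == (1, 3, 8)
# paired[0, 0] holds teacher frames 0 and 1 side by side.
assert paired[0, 0].tolist() == list(range(8))
```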
@@ -68,7 +68,8 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
+from lhotse.cut import Cut, MonoCut
+from lhotse.dataset.collation import collate_custom_field
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
@@ -403,6 +404,34 @@ def get_parser():
         help="Scale for CTC loss.",
     )
 
+    parser.add_argument(
+        "--enable-distillation",
+        type=str2bool,
+        default=True,
+        help="Whether to enable distillation.",
+    )
+
+    parser.add_argument(
+        "--codebook-loss-scale",
+        type=float,
+        default=0.1,
+        help="The scale of codebook loss.",
+    )
+
+    parser.add_argument(
+        "--num-codebooks",
+        type=int,
+        default=16,
+        help="Number of codebooks used for the extracted CI",
+    )
+
+    parser.add_argument(
+        "--distillation-layer",
+        type=int,
+        default=4,
+        help="Where to perform MVQ-KD",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
@@ -579,6 +608,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         causal=params.causal,
         chunk_size=_to_int_tuple(params.chunk_size),
         left_context_frames=_to_int_tuple(params.left_context_frames),
+        middle_output_layer=params.distillation_layer
+        if params.enable_distillation
+        else None,
     )
     return encoder
 
@@ -604,11 +636,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
 
 
 def get_model(params: AttributeDict) -> nn.Module:
-    assert (
-        params.use_transducer or params.use_ctc
-    ), (f"At least one of them should be True, "
+    assert params.use_transducer or params.use_ctc, (
+        f"At least one of them should be True, "
         f"but got params.use_transducer={params.use_transducer}, "
-        f"params.use_ctc={params.use_ctc}")
+        f"params.use_ctc={params.use_ctc}"
+    )
 
     encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
@@ -630,6 +662,8 @@ def get_model(params: AttributeDict) -> nn.Module:
         vocab_size=params.vocab_size,
         use_transducer=params.use_transducer,
         use_ctc=params.use_ctc,
+        num_codebooks=params.num_codebooks if params.enable_distillation else 0,
+        cb_input_dim=_to_int_tuple(params.encoder_dim)[params.distillation_layer],
     )
     return model
 
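On the `cb_input_dim` line above: `--encoder-dim` in this recipe is a comma-separated string with one value per encoder stack, and `_to_int_tuple` (defined elsewhere in train.py) picks out the dimension of the distillation layer. A sketch with placeholder values:

```python
def _to_int_tuple(s: str):
    # Mirrors the train.py helper: "192,256,384" -> (192, 256, 384).
    return tuple(map(int, s.split(",")))

encoder_dim = "192,256,384,512,384,256"  # placeholder per-stack dimensions
distillation_layer = 4                   # cf. the --distillation-layer default
assert _to_int_tuple(encoder_dim)[distillation_layer] == 384
```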
@@ -750,6 +784,16 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)
 
+
+def extract_codebook_indexes(batch: Dict) -> Tuple[Tensor, Tensor]:
+    cuts = batch["supervisions"]["cut"]
+    # -100 is identical to ignore_value in CE loss computation.
+    cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts]
+    codebook_indexes, codebook_indexes_lens = collate_custom_field(
+        cuts_pre_mixed, "codebook_indexes", pad_value=-100
+    )
+    return codebook_indexes, codebook_indexes_lens
+
 
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
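The `pad_value=-100` is not arbitrary: as the in-code comment notes, it matches the default `ignore_index` of PyTorch's cross-entropy, so padded frames drop out of the codebook prediction loss. A self-contained illustration:

```python
import torch
import torch.nn as nn

logits = torch.randn(4, 10)                 # 4 frames, 10 classes
targets = torch.tensor([3, 7, -100, -100])  # last two frames are padding
loss = nn.CrossEntropyLoss()(logits, targets)  # ignore_index defaults to -100
```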
@@ -791,14 +835,21 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)
 
+    if is_training and params.enable_distillation:
+        codebook_indexes, _ = extract_codebook_indexes(batch)
+        codebook_indexes = codebook_indexes.to(device)
+    else:
+        codebook_indexes = None
+
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        simple_loss, pruned_loss, ctc_loss, codebook_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            codebook_indexes=codebook_indexes,
         )
 
     loss = 0.0
@@ -808,21 +859,23 @@ def compute_loss(
         # take down the scale on the simple loss from 1.0 at the start
         # to params.simple_loss scale by warm_step.
         simple_loss_scale = (
-            s if batch_idx_train >= warm_step
+            s
+            if batch_idx_train >= warm_step
             else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
         )
         pruned_loss_scale = (
-            1.0 if batch_idx_train >= warm_step
+            1.0
+            if batch_idx_train >= warm_step
             else 0.1 + 0.9 * (batch_idx_train / warm_step)
         )
-        loss += (
-            simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
-        )
+        loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     if params.use_ctc:
         loss += params.ctc_loss_scale * ctc_loss
 
+    if is_training and params.enable_distillation:
+        loss += params.codebook_loss_scale * codebook_loss
+
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
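The scale rewrites above are purely cosmetic (black-style reformatting); both scales still ramp linearly over the first `warm_step` batches. A small numeric check of the schedule:

```python
def simple_loss_scale(batch_idx_train: int, warm_step: int, s: float) -> float:
    # Same expression as in compute_loss above.
    return (
        s
        if batch_idx_train >= warm_step
        else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
    )

assert simple_loss_scale(0, 1000, 0.5) == 1.0     # start of training
assert simple_loss_scale(500, 1000, 0.5) == 0.75  # halfway through warm-up
assert simple_loss_scale(1000, 1000, 0.5) == 0.5  # after warm-up: s
```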
@@ -837,6 +890,8 @@ def compute_loss(
         info["pruned_loss"] = pruned_loss.detach().cpu().item()
     if params.use_ctc:
         info["ctc_loss"] = ctc_loss.detach().cpu().item()
+    if is_training and params.enable_distillation:
+        info["codebook_loss"] = codebook_loss.detach().cpu().item()
 
     return loss, info
 
@@ -1105,6 +1160,13 @@ def run(rank, world_size, args):
     else:
         tb_writer = None
 
+    # Note: it's better to set --spec-aug-time-warp-factor=-1
+    # when doing distillation with vq.
+    if params.enable_distillation:
+        assert (
+            args.spec_aug_time_warp_factor < 1
+        ), "Specaug should be disabled during distillation"
+
     device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
@@ -1234,14 +1296,14 @@ def run(rank, world_size, args):
         valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         params=params,
+    #     )
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
(The diff of one more file is suppressed because it is too large.)