Merge d941c516c00d260875945064edc80eb8a00add81 into 3199058194a48d45aeee740f2aa9bdbef0bec29d

marcoyang1998 2023-09-11 18:53:54 +08:00 committed by GitHub
commit 00fefa2b22
5 changed files with 799 additions and 424 deletions

View File

@@ -25,6 +25,7 @@
 stage=0
 stop_stage=4
+. shared/parse_options.sh || exit 1

 # Set the GPUs available.
 # This script requires at least one GPU.
@@ -32,7 +33,7 @@ stop_stage=4
 # even if you only have ONE GPU. It is needed by CodebookIndexExtractor to determine the number of jobs for extracting codebook indexes in parallel.
 # Suppose only one GPU exists:
-# export CUDA_VISIBLE_DEVICES="0"
+export CUDA_VISIBLE_DEVICES="0"
 #
 # Suppose GPU 2,3,4,5 are available.
 # export CUDA_VISIBLE_DEVICES="0,1,2,3"
@@ -154,27 +155,35 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   mkdir -p codebook_dir
   codebook_download_dir=$exp_dir/download_codebook
   if [ -d $codebook_download_dir ]; then
-    log "$codebook_download_dir exists, you should remove it first."
-    exit 1
+    log "$codebook_download_dir exists, skip downloading it."
+  else
+    log "Downloading extracted codebook indexes to $codebook_download_dir"
+    # Make sure you have git-lfs installed (https://git-lfs.github.com)
+    # The codebook indexes were generated with lhotse 1.11.0; to avoid
+    # potential issues, we recommend that you use lhotse version >= 1.11.0
+    lhotse_version=$(python3 -c "import lhotse; from packaging import version; print(version.parse(lhotse.version.__version__)>=version.parse('1.11.0'))")
+    if [ "$lhotse_version" == "False" ]; then
+      log "Expecting lhotse >= 1.11.0. This may lead to potential ID mismatch."
+    fi
+    GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/marcoyang/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir
+    pushd $codebook_download_dir
+    if [ "$full_libri" == "False" ]; then
+      log "Only download the train-clean-100 subset"
+      git lfs pull --include "*clean-100*"
+    else
+      log "Download the full training set"
+      git lfs fetch --all
+    fi
+    popd
   fi
-  log "Downloading extracted codebook indexes to $codebook_download_dir"
-  # Make sure you have git-lfs installed (https://git-lfs.github.com)
-  # The codebook indexes are generated using lhotse 1.11.0, to avoid
-  # potential issues, we recommend you to use lhotse version >= 1.11.0
-  lhotse_version=$(python3 -c "import lhotse; from packaging import version; print(version.parse(lhotse.version.__version__)>=version.parse('1.11.0'))")
-  if [ "$lhotse_version" == "False" ]; then
-    log "Expecting lhotse >= 1.11.0. This may lead to potential ID mismatch."
-  fi
-  git lfs install
-  git clone https://huggingface.co/marcoyang/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir

   vq_fbank=data/vq_fbank_layer${embedding_layer}_cb${num_codebooks}/
   mkdir -p $vq_fbank
   mv $codebook_download_dir/*.jsonl.gz $vq_fbank
   mkdir -p $codebook_dir/splits4
   mv $codebook_download_dir/*.h5 $codebook_dir/splits4/
-  log "Remove $codebook_download_dir"
-  rm -rf $codebook_download_dir
+  # log "Remove $codebook_download_dir"
+  # rm -rf $codebook_download_dir
 fi

 ./pruned_transducer_stateless6/extract_codebook_index.py \
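For reference, the inline version guard in this stage boils down to the following standalone check (a minimal sketch using the same lhotse.version.__version__ attribute as the script; the packaging dependency is assumed to be installed):

    import lhotse
    from packaging import version

    # The released codebook indexes were generated with lhotse 1.11.0;
    # an older lhotse may assign different cut IDs, so the downloaded
    # indexes would no longer line up with locally prepared manifests.
    if version.parse(lhotse.version.__version__) < version.parse("1.11.0"):
        print("Expecting lhotse >= 1.11.0. This may lead to potential ID mismatch.")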

View File

@@ -27,8 +27,6 @@ from torch import nn
 from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost
 from icefall.decode import Nbest, one_best_decoding
 from icefall.lm_wrapper import LmScorer
-from icefall.rnn_lm.model import RnnLmModel
-from icefall.transformer_lm.model import TransformerLM
 from icefall.utils import (
     DecodingResults,
     add_eos,
View File

@@ -16,15 +16,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

 import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
+from multi_quantization.prediction import JointCodebookLoss
+from scaling import ScaledLinear

 from icefall.utils import add_sos, make_pad_mask
-from scaling import ScaledLinear


 class AsrModel(nn.Module):
@@ -39,12 +40,15 @@ class AsrModel(nn.Module):
         vocab_size: int = 500,
         use_transducer: bool = True,
         use_ctc: bool = False,
+        num_codebooks: int = 8,
+        cb_input_dim: int = 384,
     ):
         """A joint CTC & Transducer ASR model.

         - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
         - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
         - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
+        - Potentially with MVQ knowledge distillation (https://arxiv.org/abs/2211.00508)

         Args:
           encoder_embed:
@@ -70,6 +74,10 @@ class AsrModel(nn.Module):
             Whether use transducer head. Default: True.
           use_ctc:
             Whether use CTC head. Default: False.
+          num_codebooks:
+            Greater than 0 if we want to do MVQ knowledge distillation.
+          cb_input_dim:
+            The input dimension to the codebook loss module.
         """
         super().__init__()
@@ -111,6 +119,12 @@
                 nn.LogSoftmax(dim=-1),
             )

+        if num_codebooks > 0:
+            self.codebook_loss_net = JointCodebookLoss(
+                predictor_channels=cb_input_dim,
+                num_codebooks=num_codebooks,
+            )
+
     def forward_encoder(
         self, x: torch.Tensor, x_lens: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -127,6 +141,8 @@
             Encoder output, of shape (N, T, C).
           encoder_out_lens:
             Encoder output lengths, of shape (N,).
+          middle_out:
+            The embeddings from the middle layers of the encoder.
         """
         # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
         x, x_lens = self.encoder_embed(x, x_lens)
@@ -135,12 +151,14 @@
         src_key_padding_mask = make_pad_mask(x_lens)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

-        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
+        encoder_out, encoder_out_lens, middle_out = self.encoder(
+            x, x_lens, src_key_padding_mask
+        )

         encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
         assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)

-        return encoder_out, encoder_out_lens
+        return encoder_out, encoder_out_lens, middle_out

     def forward_ctc(
         self,
@@ -180,6 +198,7 @@
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        codebook_indexes: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute Transducer loss.
         Args:
@@ -286,6 +305,7 @@
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        codebook_indexes: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -306,9 +326,12 @@
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
+          codebook_indexes:
+            The codebook indexes to be predicted. Only used when doing knowledge
+            distillation with MVQ
         Returns:
-          Return the transducer losses and CTC loss,
-          in form of (simple_loss, pruned_loss, ctc_loss)
+          Return the transducer losses and CTC loss, and potentially codebook loss
+          in form of (simple_loss, pruned_loss, ctc_loss, codebook_loss)

         Note:
            Regarding am_scale & lm_scale, it will make the loss-function one of
@@ -323,7 +346,7 @@
         assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)

         # Compute encoder outputs
-        encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
+        encoder_out, encoder_out_lens, middle_out = self.forward_encoder(x, x_lens)

         row_splits = y.shape.row_splits(1)
         y_lens = row_splits[1:] - row_splits[:-1]
@@ -355,4 +378,85 @@
         else:
             ctc_loss = torch.empty(0)

-        return simple_loss, pruned_loss, ctc_loss
+        if self.training and hasattr(self, "codebook_loss_net"):
+            assert codebook_indexes is not None
+            codebook_loss = self.forward_codebook(
+                middle_out=middle_out,
+                codebook_indexes=codebook_indexes,
+            )
+        else:
+            codebook_loss = torch.empty(0)
+
+        return simple_loss, pruned_loss, ctc_loss, codebook_loss
+
+    def forward_codebook(
+        self,
+        middle_out: List[torch.Tensor],
+        codebook_indexes: torch.Tensor,
+    ) -> torch.Tensor:
+        """Calculate the codebook loss for the model (knowledge distillation)
+
+        Args:
+            middle_out (List[torch.Tensor]):
+                The embeddings extracted from the middle layer of the zipformer encoder
+            codebook_indexes (torch.Tensor):
+                The encoded codebook indexes for knowledge distillation
+
+        Returns:
+            The codebook loss value
+        """
+        middle_layer_output = middle_out[
+            0
+        ]  # currently only support using output of one layer, (N,T,C)
+        len_CI = codebook_indexes.size(1)
+        len_mid_layer = middle_layer_output.size(1)
+        ratio = round(len_CI / len_mid_layer)
+
+        if ratio == 1:  # having the same frame rate
+            assert len_CI > len_mid_layer, (len_CI, len_mid_layer)
+            codebook_indexes = codebook_indexes[:, :len_mid_layer, :]
+            assert codebook_indexes.size(1) == middle_layer_output.size(1)
+            codebook_loss = self.codebook_loss_net(
+                middle_layer_output, codebook_indexes
+            )
+        elif ratio == 2:
+            codebook_indexes = self.concat_successive_codebook_indexes(
+                middle_layer_output, codebook_indexes
+            )
+            codebook_loss = self.codebook_loss_net(
+                middle_layer_output, codebook_indexes
+            )
+
+        return codebook_loss
+
+    @staticmethod
+    def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes):
+        # Output rate of hubert is 50 frames per second,
+        # while that of the current encoder is 25.
+        # The following code handles two issues:
+        # 1.
+        # Roughly speaking, to generate another frame of output,
+        # hubert needs two extra frames,
+        # while the current encoder needs four extra frames.
+        # Suppose only three extra frames are provided;
+        # hubert will generate another frame while the current encoder does nothing.
+        # 2.
+        # The codebook loss is a frame-wise loss. To enable the 25-frame student
+        # output to learn from the 50-frame teacher output, two successive frames
+        # of the teacher model's output are concatenated together.
+        t_expected = middle_layer_output.shape[1]
+        N, T, C = codebook_indexes.shape
+        assert T >= t_expected, (T, t_expected)
+        # Handling issue 1.
+        if T >= t_expected * 2:
+            codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
+        if (
+            T / t_expected < 1.1
+        ):  # To be changed, dirty hack to jump out of this function
+            codebook_indexes = codebook_indexes[:, :t_expected, :]
+            assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
+            return codebook_indexes
+        # Handling issue 2.
+        codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
+        assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
+        return codebook_indexes
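To make the ratio-2 branch of concat_successive_codebook_indexes concrete, here is a minimal self-contained sketch (the tensor sizes are hypothetical, chosen only for illustration): pairs of successive 50 Hz teacher frames are stacked along the codebook dimension, so each 25 Hz student frame predicts 2 * num_codebooks indexes.

    import torch

    # Hypothetical sizes: batch of 2, student (encoder) emits 6 frames,
    # teacher (hubert) emits 13 frames of 16 codebook indexes each.
    N, t_expected, C = 2, 6, 16
    codebook_indexes = torch.randint(0, 256, (N, 13, C))

    # Issue 1: trim the teacher sequence to exactly two frames per student frame.
    codebook_indexes = codebook_indexes[:, : t_expected * 2, :]

    # Issue 2: a contiguous reshape pairs frames (0,1), (2,3), ... so that
    # (N, 2*t_expected, C) becomes (N, t_expected, 2*C).
    paired = codebook_indexes.reshape(N, t_expected, C * 2)
    assert paired.shape == (2, 6, 32)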

View File

@@ -68,7 +68,8 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
+from lhotse.cut import Cut, MonoCut
+from lhotse.dataset.collation import collate_custom_field
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
@@ -403,6 +404,34 @@ def get_parser():
         help="Scale for CTC loss.",
     )

+    parser.add_argument(
+        "--enable-distillation",
+        type=str2bool,
+        default=True,
+        help="Whether to enable distillation.",
+    )
+
+    parser.add_argument(
+        "--codebook-loss-scale",
+        type=float,
+        default=0.1,
+        help="The scale of codebook loss.",
+    )
+
+    parser.add_argument(
+        "--num-codebooks",
+        type=int,
+        default=16,
+        help="Number of codebooks used for the extracted codebook indexes.",
+    )
+
+    parser.add_argument(
+        "--distillation-layer",
+        type=int,
+        default=4,
+        help="Which encoder layer to perform MVQ-KD on.",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
@@ -579,6 +608,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         causal=params.causal,
         chunk_size=_to_int_tuple(params.chunk_size),
         left_context_frames=_to_int_tuple(params.left_context_frames),
+        middle_output_layer=params.distillation_layer
+        if params.enable_distillation
+        else None,
     )
     return encoder
@@ -604,11 +636,11 @@
 def get_model(params: AttributeDict) -> nn.Module:
-    assert (
-        params.use_transducer or params.use_ctc
-    ), (f"At least one of them should be True, "
-        f"but got params.use_transducer={params.use_transducer}, "
-        f"params.use_ctc={params.use_ctc}")
+    assert params.use_transducer or params.use_ctc, (
+        f"At least one of them should be True, "
+        f"but got params.use_transducer={params.use_transducer}, "
+        f"params.use_ctc={params.use_ctc}"
+    )

     encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
@@ -630,6 +662,8 @@ def get_model(params: AttributeDict) -> nn.Module:
         vocab_size=params.vocab_size,
         use_transducer=params.use_transducer,
         use_ctc=params.use_ctc,
+        num_codebooks=params.num_codebooks if params.enable_distillation else 0,
+        cb_input_dim=_to_int_tuple(params.encoder_dim)[params.distillation_layer],
     )
     return model
@@ -750,6 +784,16 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)


+def extract_codebook_indexes(batch: Dict) -> Tuple[Tensor, Tensor]:
+    cuts = batch["supervisions"]["cut"]
+    # -100 is identical to ignore_value in CE loss computation.
+    cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts]
+    codebook_indexes, codebook_indexes_lens = collate_custom_field(
+        cuts_pre_mixed, "codebook_indexes", pad_value=-100
+    )
+    return codebook_indexes, codebook_indexes_lens
+
+
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
@@ -791,14 +835,21 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)

+    if is_training and params.enable_distillation:
+        codebook_indexes, _ = extract_codebook_indexes(batch)
+        codebook_indexes = codebook_indexes.to(device)
+    else:
+        codebook_indexes = None
+
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        simple_loss, pruned_loss, ctc_loss, codebook_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            codebook_indexes=codebook_indexes,
         )

         loss = 0.0
@@ -808,21 +859,23 @@
             # take down the scale on the simple loss from 1.0 at the start
             # to params.simple_loss scale by warm_step.
             simple_loss_scale = (
-                s if batch_idx_train >= warm_step
+                s
+                if batch_idx_train >= warm_step
                 else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
             )
             pruned_loss_scale = (
-                1.0 if batch_idx_train >= warm_step
+                1.0
+                if batch_idx_train >= warm_step
                 else 0.1 + 0.9 * (batch_idx_train / warm_step)
             )
-            loss += (
-                simple_loss_scale * simple_loss
-                + pruned_loss_scale * pruned_loss
-            )
+            loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss

+        if is_training and params.enable_distillation:
+            loss += params.codebook_loss_scale * codebook_loss
+
     assert loss.requires_grad == is_training

     info = MetricsTracker()
@@ -837,6 +890,8 @@
         info["pruned_loss"] = pruned_loss.detach().cpu().item()
     if params.use_ctc:
         info["ctc_loss"] = ctc_loss.detach().cpu().item()
+    if is_training and params.enable_distillation:
+        info["codebook_loss"] = codebook_loss.detach().cpu().item()

     return loss, info
@@ -1105,6 +1160,13 @@ def run(rank, world_size, args):
     else:
         tb_writer = None

+    # Note: it's better to set --spec-aug-time-warp-factor=-1
+    # when doing distillation with vq.
+    if params.enable_distillation:
+        assert (
+            args.spec_aug_time_warp_factor < 1
+        ), "Specaug time warping should be disabled during distillation"
+
     device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
@@ -1234,14 +1296,14 @@
         valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         params=params,
+    #     )

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
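Taken together, the pieces added to compute_loss yield the following overall objective (a schematic sketch, not the verbatim recipe code; warm_step and the simple-loss scale s come from the surrounding training loop, and only the codebook_loss_scale=0.1 default is defined in this diff):

    # Schematic summary of the loss combination in compute_loss: warmup-scaled
    # transducer losses, plus optional CTC and codebook (MVQ distillation) terms.
    def total_loss(simple_loss, pruned_loss, ctc_loss, codebook_loss,
                   batch_idx_train, warm_step, s,
                   ctc_loss_scale, codebook_loss_scale=0.1,
                   use_ctc=False, enable_distillation=True, is_training=True):
        warmed_up = batch_idx_train >= warm_step
        # Simple loss decays from full weight to s; pruned loss ramps up
        # from 0.1 to 1.0 over the first warm_step batches.
        simple_scale = s if warmed_up else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
        pruned_scale = 1.0 if warmed_up else 0.1 + 0.9 * (batch_idx_train / warm_step)
        loss = simple_scale * simple_loss + pruned_scale * pruned_loss
        if use_ctc:
            loss += ctc_loss_scale * ctc_loss
        if is_training and enable_distillation:
            loss += codebook_loss_scale * codebook_loss
        return loss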

File diff suppressed because it is too large.