This commit is contained in:
marcoyang1998 2023-07-28 14:42:45 +08:00
parent 3705a58624
commit f626ec849b
3 changed files with 633 additions and 443 deletions

View File

@ -16,16 +16,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional, Tuple, List from typing import List, Optional, Tuple
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from multi_quantization.prediction import JointCodebookLoss from multi_quantization.prediction import JointCodebookLoss
from scaling import ScaledLinear
from icefall.utils import add_sos, make_pad_mask from icefall.utils import add_sos, make_pad_mask
from scaling import ScaledLinear
class AsrModel(nn.Module): class AsrModel(nn.Module):
@ -141,8 +141,8 @@ class AsrModel(nn.Module):
Encoder output, of shape (N, T, C). Encoder output, of shape (N, T, C).
encoder_out_lens: encoder_out_lens:
Encoder output lengths, of shape (N,). Encoder output lengths, of shape (N,).
saved_embeddings: saved_embeddings:
The embeddings from the middle layers The embeddings from the middle layers
""" """
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
x, x_lens = self.encoder_embed(x, x_lens) x, x_lens = self.encoder_embed(x, x_lens)
@ -151,7 +151,9 @@ class AsrModel(nn.Module):
src_key_padding_mask = make_pad_mask(x_lens) src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens, middle_out = self.encoder(x, x_lens, src_key_padding_mask) encoder_out, encoder_out_lens, middle_out = self.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
@ -324,9 +326,9 @@ class AsrModel(nn.Module):
lm_scale: lm_scale:
The scale to smooth the loss with lm (output of predictor network) The scale to smooth the loss with lm (output of predictor network)
part part
codebook_indexes: codebook_indexes:
The codebook indexes to be predicted. Only used when doing knowledge The codebook indexes to be predicted. Only used when doing knowledge
distillation with MVQ distillation with MVQ
Returns: Returns:
Return the transducer losses and CTC loss, and potentially codebook loss Return the transducer losses and CTC loss, and potentially codebook loss
in form of (simple_loss, pruned_loss, ctc_loss, codebook_loss) in form of (simple_loss, pruned_loss, ctc_loss, codebook_loss)
@ -379,9 +381,9 @@ class AsrModel(nn.Module):
if self.training and hasattr(self, "codebook_loss_net"): if self.training and hasattr(self, "codebook_loss_net"):
assert codebook_indexes is not None assert codebook_indexes is not None
codebook_loss = self.forward_codebook( codebook_loss = self.forward_codebook(
middle_out=middle_out, middle_out=middle_out,
codebook_indexes=codebook_indexes, codebook_indexes=codebook_indexes,
) )
else: else:
codebook_loss = torch.empty(0) codebook_loss = torch.empty(0)
@ -394,21 +396,23 @@ class AsrModel(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
"""Calculate the codebook loss for the model (knowledge distillation) """Calculate the codebook loss for the model (knowledge distillation)
Args: Args:
middle_out (List[torch.Tensor]): middle_out (List[torch.Tensor]):
The embeddings extracted from the middle layer of the zipformer encoder The embeddings extracted from the middle layer of the zipformer encoder
codebook_indexes (torch.Tensor): codebook_indexes (torch.Tensor):
The encoded codebook indexes for knowledge distillation The encoded codebook indexes for knowledge distillation
Returns: Returns:
The codebook loss value The codebook loss value
""" """
middle_layer_output = middle_out[0] # currently only support using output of one layer, (N,T,C) middle_layer_output = middle_out[
0
] # currently only support using output of one layer, (N,T,C)
len_CI = codebook_indexes.size(1) len_CI = codebook_indexes.size(1)
len_mid_layer = middle_layer_output.size(1) len_mid_layer = middle_layer_output.size(1)
ratio = round(len_CI/len_mid_layer) ratio = round(len_CI / len_mid_layer)
if ratio == 1: # Having the same frame rate if ratio == 1: # Having the same frame rate
assert len_CI > len_mid_layer, (len_CI, len_mid_layer) assert len_CI > len_mid_layer, (len_CI, len_mid_layer)
codebook_indexes = codebook_indexes[:, :len_mid_layer, :] codebook_indexes = codebook_indexes[:, :len_mid_layer, :]
assert codebook_indexes.size(1) == middle_layer_output.size(1) assert codebook_indexes.size(1) == middle_layer_output.size(1)
@ -426,9 +430,7 @@ class AsrModel(nn.Module):
return codebook_loss return codebook_loss
@staticmethod @staticmethod
def concat_successive_codebook_indexes( def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes):
middle_layer_output, codebook_indexes
):
# Output rate of hubert is 50 frames per second, # Output rate of hubert is 50 frames per second,
# while that of current encoder is 25. # while that of current encoder is 25.
# Following code handling two issues: # Following code handling two issues:
@ -448,10 +450,12 @@ class AsrModel(nn.Module):
# Handling issue 1. # Handling issue 1.
if T >= t_expected * 2: if T >= t_expected * 2:
codebook_indexes = codebook_indexes[:, : t_expected * 2, :] codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
if T / t_expected < 1.1: # To be changed, dirty hack to jump out of this function if (
codebook_indexes = codebook_indexes[:, : t_expected, :] T / t_expected < 1.1
assert middle_layer_output.shape[1] == codebook_indexes.shape[1] ): # To be changed, dirty hack to jump out of this function
return codebook_indexes codebook_indexes = codebook_indexes[:, :t_expected, :]
assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
return codebook_indexes
# Handling issue 2. # Handling issue 2.
codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2) codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
assert middle_layer_output.shape[1] == codebook_indexes.shape[1] assert middle_layer_output.shape[1] == codebook_indexes.shape[1]

View File

@ -636,11 +636,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
def get_model(params: AttributeDict) -> nn.Module: def get_model(params: AttributeDict) -> nn.Module:
assert ( assert params.use_transducer or params.use_ctc, (
params.use_transducer or params.use_ctc f"At least one of them should be True, "
), (f"At least one of them should be True, "
f"but got params.use_transducer={params.use_transducer}, " f"but got params.use_transducer={params.use_transducer}, "
f"params.use_ctc={params.use_ctc}") f"params.use_ctc={params.use_ctc}"
)
encoder_embed = get_encoder_embed(params) encoder_embed = get_encoder_embed(params)
encoder = get_encoder_model(params) encoder = get_encoder_model(params)
@ -783,17 +783,17 @@ def save_checkpoint(
best_valid_filename = params.exp_dir / "best-valid-loss.pt" best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename) copyfile(src=filename, dst=best_valid_filename)
def extract_codebook_indexes(batch: Dict) -> Tuple[Tensor, Tensor]: def extract_codebook_indexes(batch: Dict) -> Tuple[Tensor, Tensor]:
cuts = batch["supervisions"]["cut"] cuts = batch["supervisions"]["cut"]
# -100 is identical to ignore_value in CE loss computation. # -100 is identical to ignore_value in CE loss computation.
cuts_pre_mixed = [ cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts]
c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
]
codebook_indexes, codebook_indexes_lens = collate_custom_field( codebook_indexes, codebook_indexes_lens = collate_custom_field(
cuts_pre_mixed, "codebook_indexes", pad_value=-100 cuts_pre_mixed, "codebook_indexes", pad_value=-100
) )
return codebook_indexes, codebook_indexes_lens return codebook_indexes, codebook_indexes_lens
def compute_loss( def compute_loss(
params: AttributeDict, params: AttributeDict,
model: Union[nn.Module, DDP], model: Union[nn.Module, DDP],
@ -859,17 +859,16 @@ def compute_loss(
# take down the scale on the simple loss from 1.0 at the start # take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step. # to params.simple_loss scale by warm_step.
simple_loss_scale = ( simple_loss_scale = (
s if batch_idx_train >= warm_step s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
) )
pruned_loss_scale = ( pruned_loss_scale = (
1.0 if batch_idx_train >= warm_step 1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step) else 0.1 + 0.9 * (batch_idx_train / warm_step)
) )
loss += ( loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
if params.use_ctc: if params.use_ctc:
loss += params.ctc_loss_scale * ctc_loss loss += params.ctc_loss_scale * ctc_loss
@ -1164,7 +1163,9 @@ def run(rank, world_size, args):
# Note: it's better to set --spec-aug-time-warpi-factor=-1 # Note: it's better to set --spec-aug-time-warpi-factor=-1
# when doing distillation with vq. # when doing distillation with vq.
if params.enable_distillation: if params.enable_distillation:
assert args.spec_aug_time_warp_factor < 1, "Specaug should be disabled during distillation" assert (
args.spec_aug_time_warp_factor < 1
), "Specaug should be disabled during distillation"
device = torch.device("cpu") device = torch.device("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():

File diff suppressed because it is too large Load Diff