This commit is contained in:
marcoyang1998 2023-07-28 14:42:45 +08:00
parent 3705a58624
commit f626ec849b
3 changed files with 633 additions and 443 deletions

View File

@ -16,16 +16,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, List
from typing import List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from multi_quantization.prediction import JointCodebookLoss
from scaling import ScaledLinear
from icefall.utils import add_sos, make_pad_mask
from scaling import ScaledLinear
class AsrModel(nn.Module):
@ -118,7 +118,7 @@ class AsrModel(nn.Module):
nn.Linear(encoder_dim, vocab_size),
nn.LogSoftmax(dim=-1),
)
if num_codebooks > 0:
self.codebook_loss_net = JointCodebookLoss(
predictor_channels=cb_input_dim,
@ -141,8 +141,8 @@ class AsrModel(nn.Module):
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
saved_embeddings:
The embeddings from the middle layers
saved_embeddings:
The embeddings from the middle layers
"""
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
x, x_lens = self.encoder_embed(x, x_lens)
@ -151,7 +151,9 @@ class AsrModel(nn.Module):
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens, middle_out = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out, encoder_out_lens, middle_out = self.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
@ -324,9 +326,9 @@ class AsrModel(nn.Module):
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
codebook_indexes:
The codebook indexes to be predicted. Only used when doing knowledge
distillation with MVQ
codebook_indexes:
The codebook indexes to be predicted. Only used when doing knowledge
distillation with MVQ
Returns:
Return the transducer losses and CTC loss, and potentially codebook loss
in form of (simple_loss, pruned_loss, ctc_loss, codebook_loss)
@ -375,18 +377,18 @@ class AsrModel(nn.Module):
)
else:
ctc_loss = torch.empty(0)
if self.training and hasattr(self, "codebook_loss_net"):
assert codebook_indexes is not None
codebook_loss = self.forward_codebook(
middle_out=middle_out,
codebook_indexes=codebook_indexes,
)
middle_out=middle_out,
codebook_indexes=codebook_indexes,
)
else:
codebook_loss = torch.empty(0)
return simple_loss, pruned_loss, ctc_loss, codebook_loss
def forward_codebook(
self,
middle_out: List[torch.Tensor],
@ -394,21 +396,23 @@ class AsrModel(nn.Module):
) -> torch.Tensor:
"""Calculate the codebook loss for the model (knowledge distillation)
Args:
middle_out (List[torch.Tensor]):
The embeddings extracted from the middle layer of the zipformer encoder
codebook_indexes (torch.Tensor):
The encoded codebook indexes for knowledge distillation
Args:
middle_out (List[torch.Tensor]):
The embeddings extracted from the middle layer of the zipformer encoder
codebook_indexes (torch.Tensor):
The encoded codebook indexes for knowledge distillation
Returns:
The codebook loss value
"""
middle_layer_output = middle_out[0] # currently only support using output of one layer, (N,T,C)
Returns:
The codebook loss value
"""
middle_layer_output = middle_out[
0
] # currently only support using output of one layer, (N,T,C)
len_CI = codebook_indexes.size(1)
len_mid_layer = middle_layer_output.size(1)
ratio = round(len_CI/len_mid_layer)
if ratio == 1: # Having the same frame rate
ratio = round(len_CI / len_mid_layer)
if ratio == 1: # Having the same frame rate
assert len_CI > len_mid_layer, (len_CI, len_mid_layer)
codebook_indexes = codebook_indexes[:, :len_mid_layer, :]
assert codebook_indexes.size(1) == middle_layer_output.size(1)
@ -422,13 +426,11 @@ class AsrModel(nn.Module):
codebook_loss = self.codebook_loss_net(
middle_layer_output, codebook_indexes
)
return codebook_loss
@staticmethod
def concat_successive_codebook_indexes(
middle_layer_output, codebook_indexes
):
def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes):
# The output rate of hubert is 50 frames per second,
# while that of the current encoder is 25.
# The following code handles two issues:
@ -448,10 +450,12 @@ class AsrModel(nn.Module):
# Handling issue 1.
if T >= t_expected * 2:
codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
if T / t_expected < 1.1: # To be changed, dirty hack to jump out of this function
codebook_indexes = codebook_indexes[:, : t_expected, :]
assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
return codebook_indexes
if (
T / t_expected < 1.1
): # To be changed, dirty hack to jump out of this function
codebook_indexes = codebook_indexes[:, :t_expected, :]
assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
return codebook_indexes
# Handling issue 2.
codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
assert middle_layer_output.shape[1] == codebook_indexes.shape[1]

View File

@ -403,28 +403,28 @@ def get_parser():
default=0.2,
help="Scale for CTC loss.",
)
parser.add_argument(
"--enable-distillation",
type=str2bool,
default=True,
help="Whether to eanble distillation.",
)
parser.add_argument(
"--codebook-loss-scale",
type=float,
default=0.1,
help="The scale of codebook loss.",
)
parser.add_argument(
"--num-codebooks",
type=int,
default=16,
help="Number of codebooks used for the extracted CI",
)
parser.add_argument(
"--distillation-layer",
type=int,
@ -636,11 +636,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
def get_model(params: AttributeDict) -> nn.Module:
assert (
params.use_transducer or params.use_ctc
), (f"At least one of them should be True, "
assert params.use_transducer or params.use_ctc, (
f"At least one of them should be True, "
f"but got params.use_transducer={params.use_transducer}, "
f"params.use_ctc={params.use_ctc}")
f"params.use_ctc={params.use_ctc}"
)
encoder_embed = get_encoder_embed(params)
encoder = get_encoder_model(params)
@ -783,17 +783,17 @@ def save_checkpoint(
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def extract_codebook_indexes(batch: Dict) -> Tuple[Tensor, Tensor]:
    """Collate the codebook indexes attached to the cuts of a batch.

    Args:
      batch:
        A batch dict. Only ``batch["supervisions"]["cut"]`` is read here.

    Returns:
      The ``(codebook_indexes, codebook_indexes_lens)`` pair returned by
      ``collate_custom_field`` for the ``"codebook_indexes"`` custom field.
    """
    cuts = batch["supervisions"]["cut"]
    # A cut that is not a MonoCut is replaced by the cut of its first track
    # (presumably the utterance before mixing -- TODO confirm against lhotse).
    cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts]
    # Pad with -100, which is identical to ignore_value in CE loss computation.
    codebook_indexes, codebook_indexes_lens = collate_custom_field(
        cuts_pre_mixed, "codebook_indexes", pad_value=-100
    )
    return codebook_indexes, codebook_indexes_lens
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
@ -834,7 +834,7 @@ def compute_loss(
texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y)
if is_training and params.enable_distillation:
codebook_indexes, _ = extract_codebook_indexes(batch)
codebook_indexes = codebook_indexes.to(device)
@ -859,21 +859,20 @@ def compute_loss(
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
s if batch_idx_train >= warm_step
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0 if batch_idx_train >= warm_step
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss += (
simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
if params.use_ctc:
loss += params.ctc_loss_scale * ctc_loss
if is_training and params.enable_distillation:
loss += params.codebook_loss_scale * codebook_loss
@ -1164,7 +1163,9 @@ def run(rank, world_size, args):
# Note: it's better to set --spec-aug-time-warp-factor=-1
# when doing distillation with vq.
if params.enable_distillation:
assert args.spec_aug_time_warp_factor < 1, "Specaug should be disabled during distillation"
assert (
args.spec_aug_time_warp_factor < 1
), "Specaug should be disabled during distillation"
device = torch.device("cpu")
if torch.cuda.is_available():

File diff suppressed because it is too large Load Diff