This commit is contained in:
marcoyang1998 2023-07-28 14:42:45 +08:00
parent 3705a58624
commit f626ec849b
3 changed files with 633 additions and 443 deletions

View File

@ -16,16 +16,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional, Tuple, List from typing import List, Optional, Tuple
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from multi_quantization.prediction import JointCodebookLoss from multi_quantization.prediction import JointCodebookLoss
from scaling import ScaledLinear
from icefall.utils import add_sos, make_pad_mask from icefall.utils import add_sos, make_pad_mask
from scaling import ScaledLinear
class AsrModel(nn.Module): class AsrModel(nn.Module):
@ -151,7 +151,9 @@ class AsrModel(nn.Module):
src_key_padding_mask = make_pad_mask(x_lens) src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens, middle_out = self.encoder(x, x_lens, src_key_padding_mask) encoder_out, encoder_out_lens, middle_out = self.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
@ -403,10 +405,12 @@ class AsrModel(nn.Module):
Returns: Returns:
The codebook loss value The codebook loss value
""" """
middle_layer_output = middle_out[0] # currently only support using output of one layer, (N,T,C) middle_layer_output = middle_out[
0
] # currently only support using output of one layer, (N,T,C)
len_CI = codebook_indexes.size(1) len_CI = codebook_indexes.size(1)
len_mid_layer = middle_layer_output.size(1) len_mid_layer = middle_layer_output.size(1)
ratio = round(len_CI/len_mid_layer) ratio = round(len_CI / len_mid_layer)
if ratio == 1: # Having the same frame rate if ratio == 1: # Having the same frame rate
assert len_CI > len_mid_layer, (len_CI, len_mid_layer) assert len_CI > len_mid_layer, (len_CI, len_mid_layer)
@ -426,9 +430,7 @@ class AsrModel(nn.Module):
return codebook_loss return codebook_loss
@staticmethod @staticmethod
def concat_successive_codebook_indexes( def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes):
middle_layer_output, codebook_indexes
):
# Output rate of hubert is 50 frames per second, # Output rate of hubert is 50 frames per second,
# while that of current encoder is 25. # while that of current encoder is 25.
# Following code handling two issues: # Following code handling two issues:
@ -448,8 +450,10 @@ class AsrModel(nn.Module):
# Handling issue 1. # Handling issue 1.
if T >= t_expected * 2: if T >= t_expected * 2:
codebook_indexes = codebook_indexes[:, : t_expected * 2, :] codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
if T / t_expected < 1.1: # To be changed, dirty hack to jump out of this function if (
codebook_indexes = codebook_indexes[:, : t_expected, :] T / t_expected < 1.1
): # To be changed, dirty hack to jump out of this function
codebook_indexes = codebook_indexes[:, :t_expected, :]
assert middle_layer_output.shape[1] == codebook_indexes.shape[1] assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
return codebook_indexes return codebook_indexes
# Handling issue 2. # Handling issue 2.

View File

@ -636,11 +636,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
def get_model(params: AttributeDict) -> nn.Module: def get_model(params: AttributeDict) -> nn.Module:
assert ( assert params.use_transducer or params.use_ctc, (
params.use_transducer or params.use_ctc f"At least one of them should be True, "
), (f"At least one of them should be True, "
f"but got params.use_transducer={params.use_transducer}, " f"but got params.use_transducer={params.use_transducer}, "
f"params.use_ctc={params.use_ctc}") f"params.use_ctc={params.use_ctc}"
)
encoder_embed = get_encoder_embed(params) encoder_embed = get_encoder_embed(params)
encoder = get_encoder_model(params) encoder = get_encoder_model(params)
@ -783,17 +783,17 @@ def save_checkpoint(
best_valid_filename = params.exp_dir / "best-valid-loss.pt" best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename) copyfile(src=filename, dst=best_valid_filename)
def extract_codebook_indexes(batch: Dict) -> Tuple[Tensor, Tensor]: def extract_codebook_indexes(batch: Dict) -> Tuple[Tensor, Tensor]:
cuts = batch["supervisions"]["cut"] cuts = batch["supervisions"]["cut"]
# -100 is identical to ignore_value in CE loss computation. # -100 is identical to ignore_value in CE loss computation.
cuts_pre_mixed = [ cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts]
c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
]
codebook_indexes, codebook_indexes_lens = collate_custom_field( codebook_indexes, codebook_indexes_lens = collate_custom_field(
cuts_pre_mixed, "codebook_indexes", pad_value=-100 cuts_pre_mixed, "codebook_indexes", pad_value=-100
) )
return codebook_indexes, codebook_indexes_lens return codebook_indexes, codebook_indexes_lens
def compute_loss( def compute_loss(
params: AttributeDict, params: AttributeDict,
model: Union[nn.Module, DDP], model: Union[nn.Module, DDP],
@ -859,17 +859,16 @@ def compute_loss(
# take down the scale on the simple loss from 1.0 at the start # take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step. # to params.simple_loss scale by warm_step.
simple_loss_scale = ( simple_loss_scale = (
s if batch_idx_train >= warm_step s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
) )
pruned_loss_scale = ( pruned_loss_scale = (
1.0 if batch_idx_train >= warm_step 1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step) else 0.1 + 0.9 * (batch_idx_train / warm_step)
) )
loss += ( loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
if params.use_ctc: if params.use_ctc:
loss += params.ctc_loss_scale * ctc_loss loss += params.ctc_loss_scale * ctc_loss
@ -1164,7 +1163,9 @@ def run(rank, world_size, args):
# Note: it's better to set --spec-aug-time-warpi-factor=-1 # Note: it's better to set --spec-aug-time-warpi-factor=-1
# when doing distillation with vq. # when doing distillation with vq.
if params.enable_distillation: if params.enable_distillation:
assert args.spec_aug_time_warp_factor < 1, "Specaug should be disabled during distillation" assert (
args.spec_aug_time_warp_factor < 1
), "Specaug should be disabled during distillation"
device = torch.device("cpu") device = torch.device("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():

File diff suppressed because it is too large Load Diff