mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 17:14:20 +00:00
format
This commit is contained in:
parent
3705a58624
commit
f626ec849b
@ -16,16 +16,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Optional, Tuple, List
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
from multi_quantization.prediction import JointCodebookLoss
|
from multi_quantization.prediction import JointCodebookLoss
|
||||||
|
from scaling import ScaledLinear
|
||||||
|
|
||||||
from icefall.utils import add_sos, make_pad_mask
|
from icefall.utils import add_sos, make_pad_mask
|
||||||
from scaling import ScaledLinear
|
|
||||||
|
|
||||||
|
|
||||||
class AsrModel(nn.Module):
|
class AsrModel(nn.Module):
|
||||||
@ -151,7 +151,9 @@ class AsrModel(nn.Module):
|
|||||||
src_key_padding_mask = make_pad_mask(x_lens)
|
src_key_padding_mask = make_pad_mask(x_lens)
|
||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
encoder_out, encoder_out_lens, middle_out = self.encoder(x, x_lens, src_key_padding_mask)
|
encoder_out, encoder_out_lens, middle_out = self.encoder(
|
||||||
|
x, x_lens, src_key_padding_mask
|
||||||
|
)
|
||||||
|
|
||||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
|
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
|
||||||
@ -403,10 +405,12 @@ class AsrModel(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
The codebook loss value
|
The codebook loss value
|
||||||
"""
|
"""
|
||||||
middle_layer_output = middle_out[0] # currently only support using output of one layer, (N,T,C)
|
middle_layer_output = middle_out[
|
||||||
|
0
|
||||||
|
] # currently only support using output of one layer, (N,T,C)
|
||||||
len_CI = codebook_indexes.size(1)
|
len_CI = codebook_indexes.size(1)
|
||||||
len_mid_layer = middle_layer_output.size(1)
|
len_mid_layer = middle_layer_output.size(1)
|
||||||
ratio = round(len_CI/len_mid_layer)
|
ratio = round(len_CI / len_mid_layer)
|
||||||
|
|
||||||
if ratio == 1: # Having the same frame rate
|
if ratio == 1: # Having the same frame rate
|
||||||
assert len_CI > len_mid_layer, (len_CI, len_mid_layer)
|
assert len_CI > len_mid_layer, (len_CI, len_mid_layer)
|
||||||
@ -426,9 +430,7 @@ class AsrModel(nn.Module):
|
|||||||
return codebook_loss
|
return codebook_loss
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def concat_successive_codebook_indexes(
|
def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes):
|
||||||
middle_layer_output, codebook_indexes
|
|
||||||
):
|
|
||||||
# Output rate of hubert is 50 frames per second,
|
# Output rate of hubert is 50 frames per second,
|
||||||
# while that of current encoder is 25.
|
# while that of current encoder is 25.
|
||||||
# The following code handles two issues:
|
# The following code handles two issues:
|
||||||
@ -448,8 +450,10 @@ class AsrModel(nn.Module):
|
|||||||
# Handling issue 1.
|
# Handling issue 1.
|
||||||
if T >= t_expected * 2:
|
if T >= t_expected * 2:
|
||||||
codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
|
codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
|
||||||
if T / t_expected < 1.1: # To be changed, dirty hack to jump out of this function
|
if (
|
||||||
codebook_indexes = codebook_indexes[:, : t_expected, :]
|
T / t_expected < 1.1
|
||||||
|
): # To be changed, dirty hack to jump out of this function
|
||||||
|
codebook_indexes = codebook_indexes[:, :t_expected, :]
|
||||||
assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
|
assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
|
||||||
return codebook_indexes
|
return codebook_indexes
|
||||||
# Handling issue 2.
|
# Handling issue 2.
|
||||||
|
@ -636,11 +636,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
|||||||
|
|
||||||
|
|
||||||
def get_model(params: AttributeDict) -> nn.Module:
|
def get_model(params: AttributeDict) -> nn.Module:
|
||||||
assert (
|
assert params.use_transducer or params.use_ctc, (
|
||||||
params.use_transducer or params.use_ctc
|
f"At least one of them should be True, "
|
||||||
), (f"At least one of them should be True, "
|
|
||||||
f"but got params.use_transducer={params.use_transducer}, "
|
f"but got params.use_transducer={params.use_transducer}, "
|
||||||
f"params.use_ctc={params.use_ctc}")
|
f"params.use_ctc={params.use_ctc}"
|
||||||
|
)
|
||||||
|
|
||||||
encoder_embed = get_encoder_embed(params)
|
encoder_embed = get_encoder_embed(params)
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
@ -783,17 +783,17 @@ def save_checkpoint(
|
|||||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||||
copyfile(src=filename, dst=best_valid_filename)
|
copyfile(src=filename, dst=best_valid_filename)
|
||||||
|
|
||||||
|
|
||||||
def extract_codebook_indexes(batch: Dict) -> Tuple[Tensor, Tensor]:
|
def extract_codebook_indexes(batch: Dict) -> Tuple[Tensor, Tensor]:
|
||||||
cuts = batch["supervisions"]["cut"]
|
cuts = batch["supervisions"]["cut"]
|
||||||
# -100 is identical to ignore_value in CE loss computation.
|
# -100 is identical to ignore_value in CE loss computation.
|
||||||
cuts_pre_mixed = [
|
cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts]
|
||||||
c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
|
|
||||||
]
|
|
||||||
codebook_indexes, codebook_indexes_lens = collate_custom_field(
|
codebook_indexes, codebook_indexes_lens = collate_custom_field(
|
||||||
cuts_pre_mixed, "codebook_indexes", pad_value=-100
|
cuts_pre_mixed, "codebook_indexes", pad_value=-100
|
||||||
)
|
)
|
||||||
return codebook_indexes, codebook_indexes_lens
|
return codebook_indexes, codebook_indexes_lens
|
||||||
|
|
||||||
|
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: Union[nn.Module, DDP],
|
model: Union[nn.Module, DDP],
|
||||||
@ -859,17 +859,16 @@ def compute_loss(
|
|||||||
# take down the scale on the simple loss from 1.0 at the start
|
# take down the scale on the simple loss from 1.0 at the start
|
||||||
# to params.simple_loss scale by warm_step.
|
# to params.simple_loss scale by warm_step.
|
||||||
simple_loss_scale = (
|
simple_loss_scale = (
|
||||||
s if batch_idx_train >= warm_step
|
s
|
||||||
|
if batch_idx_train >= warm_step
|
||||||
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
|
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
|
||||||
)
|
)
|
||||||
pruned_loss_scale = (
|
pruned_loss_scale = (
|
||||||
1.0 if batch_idx_train >= warm_step
|
1.0
|
||||||
|
if batch_idx_train >= warm_step
|
||||||
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
||||||
)
|
)
|
||||||
loss += (
|
loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||||
simple_loss_scale * simple_loss
|
|
||||||
+ pruned_loss_scale * pruned_loss
|
|
||||||
)
|
|
||||||
|
|
||||||
if params.use_ctc:
|
if params.use_ctc:
|
||||||
loss += params.ctc_loss_scale * ctc_loss
|
loss += params.ctc_loss_scale * ctc_loss
|
||||||
@ -1164,7 +1163,9 @@ def run(rank, world_size, args):
|
|||||||
# Note: it's better to set --spec-aug-time-warp-factor=-1
|
# Note: it's better to set --spec-aug-time-warp-factor=-1
|
||||||
# when doing distillation with vq.
|
# when doing distillation with vq.
|
||||||
if params.enable_distillation:
|
if params.enable_distillation:
|
||||||
assert args.spec_aug_time_warp_factor < 1, "Specaug should be disabled during distillation"
|
assert (
|
||||||
|
args.spec_aug_time_warp_factor < 1
|
||||||
|
), "Specaug should be disabled during distillation"
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user