From 802349302940bb83870724c0f70b1cf47dc081c2 Mon Sep 17 00:00:00 2001
From: Yifan Yang <64255737+yfyeung@users.noreply.github.com>
Date: Mon, 1 Jan 2024 20:09:22 +0800
Subject: [PATCH] update

---
 egs/librispeech/SSL/hubert/dataset.py  | 55 +--------------
 egs/librispeech/SSL/hubert/decode.py   | 37 +++++-----
 egs/librispeech/SSL/hubert/finetune.py | 98 +++++++++++++-------------
 egs/librispeech/SSL/hubert/model.py    | 10 ++-
 4 files changed, 77 insertions(+), 123 deletions(-)

diff --git a/egs/librispeech/SSL/hubert/dataset.py b/egs/librispeech/SSL/hubert/dataset.py
index c3442df51..d97fe9945 100644
--- a/egs/librispeech/SSL/hubert/dataset.py
+++ b/egs/librispeech/SSL/hubert/dataset.py
@@ -25,53 +25,6 @@ from torch.utils.data.dataloader import default_collate
 from transformers import Wav2Vec2FeatureExtractor
 
 
-class HubertDataset(torch.utils.data.Dataset):
-    """
-    In this implementation, there will always be a single channel.
-
-    Returns:
-
-    .. code-block::
-
-        {
-            'audio': (B x NumSamples) float tensor
-            'audio_lens': (B, ) int tensor
-        }
-    """
-
-    def __init__(self, collate: bool = True) -> None:
-        super().__init__()
-        self.feature_extractor = Wav2Vec2FeatureExtractor(
-            feature_size=1,
-            sampling_rate=16000,
-            padding_side="right",
-            padding_value=0.0,
-            do_normalize=True,
-            return_attention_mask=True,
-        )
-
-    def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
-        self._validate(cuts)
-        audio, _ = read_audio_from_cuts(cuts, return_tensors=False)
-        audio = self.feature_extractor(
-            audio,
-            padding=True,
-            return_tensors="pt",
-            sampling_rate=16000,
-        ).input_values
-        audio_lens = torch.tensor([cut.num_samples for cut in cuts], dtype=torch.int32)
-
-        return {
-            "cuts": cuts,
-            "audio": audio,
-            "audio_lens": audio_lens,
-        }
-
-    def _validate(self, cuts: CutSet) -> None:
-        validate(cuts)
-        assert all(cut.has_recording for cut in cuts)
-
-
 class HubertAsrDataset(torch.utils.data.Dataset):
     """
     In this implementation, there will always be a single channel.
@@ -94,7 +47,8 @@ class HubertAsrDataset(torch.utils.data.Dataset):
             padding_side="right",
             padding_value=0,
             do_normalize=True,
-            return_attention_mask=False,
+            return_attention_mask=True,
+            feature_extractor_type="Wav2Vec2FeatureExtractor",
         )
 
     def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
@@ -148,7 +102,4 @@ if __name__ == "__main__":
     )
 
     for batch_idx, batch in enumerate(dl):
-        print(batch["audio"])
-        print(batch["audio_lens"])
-        print(batch["supervisions"]["text"])
-        print(batch["cuts"])
+        break
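For reference, `HubertAsrDataset` above now asks the extractor for an attention mask. A minimal sketch, not part of the patch and assuming `torch` plus HuggingFace `transformers` are installed, of how `Wav2Vec2FeatureExtractor` pads a batch of raw waveforms:

import torch
from transformers import Wav2Vec2FeatureExtractor

extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)

# Two mono waveforms of different lengths (1 s and 1.5 s at 16 kHz).
waveforms = [torch.randn(16000).numpy(), torch.randn(24000).numpy()]
batch = extractor(waveforms, padding=True, return_tensors="pt", sampling_rate=16000)

# input_values is (B, max_num_samples): shorter items are right-padded with
# padding_value, and attention_mask marks which samples are real audio.
print(batch.input_values.shape)         # torch.Size([2, 24000])
print(batch.attention_mask.sum(dim=1))  # tensor([16000, 24000])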
""" device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 + audio = batch["audio"].to(device) + audio_lens = torch.full(audio.shape[:1], audio.shape[1], dtype=torch.int32) - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + encoder_out, encoder_out_lens = model.forward_encoder(audio, audio_lens) hyps = [] @@ -665,7 +659,7 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + cut_ids = [cut.id for cut in batch["cuts"]] hyps_dict = decode_one_batch( params=params, @@ -996,14 +990,23 @@ def main(): args.return_cuts = True librispeech = LibriSpeechAsrDataModule(args) - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() + dev_clean_cuts = librispeech.dev_clean_cuts() + dev_other_cuts = librispeech.dev_other_cuts() - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) + dev_clean_dl = librispeech.test_dataloaders(dev_clean_cuts) + dev_other_dl = librispeech.test_dataloaders(dev_other_cuts) - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] + test_sets = ["dev-clean", "dev-other"] + test_dl = [dev_clean_dl, dev_other_dl] + + # test_clean_cuts = librispeech.test_clean_cuts() + # test_other_cuts = librispeech.test_other_cuts() + + # test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + # test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + # test_sets = ["test-clean", "test-other"] + # test_dl = [test_clean_dl, test_other_dl] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( diff --git a/egs/librispeech/SSL/hubert/finetune.py b/egs/librispeech/SSL/hubert/finetune.py index 0c0095f9f..ad0ae4199 100644 --- a/egs/librispeech/SSL/hubert/finetune.py +++ b/egs/librispeech/SSL/hubert/finetune.py @@ -31,8 +31,8 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" --start-epoch 1 \ --use-fp16 0 \ --exp-dir hubert/exp \ - --full-libri 1 \ - --max-duration 80 + --full-libri 0 \ + --max-duration 200 It supports finetuning with: - transducer loss (default), with `--use-transducer True --use-ctc False` @@ -63,7 +63,6 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel from optim import Eden, ScaledAdam -from scaling import ScheduledFloat from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP @@ -216,17 +215,17 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--mask-feature-length", type=int, - default=10, + default=64, ) parser.add_argument( "--mask-feature-min-masks", type=int, - default=0, + default=2, ) parser.add_argument( "--mask-feature-prob", type=float, - default=0.0, + default=0.5, ) parser.add_argument( "--mask-time-length", @@ -236,12 +235,12 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--mask-time-min-masks", type=int, - default=2, + default=10, ) parser.add_argument( "--mask-time-prob", type=float, - default=0.05, + default=0.65, ) parser.add_argument( "--num-attention-heads", @@ -361,7 +360,6 @@ def get_parser(): 
parser.add_argument( "--pretrained-dir", type=str, - default="download/hubert-base-ls960", help="""The pretrained model dir. It specifies the directory where the pretrained checkpoint is saved.""", ) @@ -374,7 +372,7 @@ def get_parser(): ) parser.add_argument( - "--base-lr", type=float, default=0.0005, help="The base learning rate." + "--base-lr", type=float, default=0.001, help="The base learning rate." ) parser.add_argument( @@ -608,40 +606,43 @@ def _get_feat_extract_output_lengths( def get_encoder_model(params: AttributeDict) -> nn.Module: - config = HubertConfig( - hidden_size=params.hidden_size, - num_hidden_layers=params.num_hidden_layers, - num_attention_heads=params.num_attention_heads, - intermediate_size=params.intermediate_size, - hidden_act=params.hidden_act, - hidden_dropout=params.hidden_dropout, - activation_dropout=params.activation_dropout, - attention_dropout=params.attention_dropout, - feat_proj_layer_norm=params.feat_proj_layer_norm, - feat_proj_dropout=params.feat_proj_dropout, - final_dropout=params.final_dropout, - layerdrop=params.layerdrop, - initializer_range=params.initializer_range, - layer_norm_eps=params.layer_norm_eps, - feat_extract_norm=params.feat_extract_norm, - feat_extract_activation=params.feat_extract_activation, - conv_dim=_to_int_tuple(params.conv_dim), - conv_stride=_to_int_tuple(params.conv_stride), - conv_kernel=_to_int_tuple(params.conv_kernel), - conv_bias=params.conv_bias, - num_conv_pos_embeddings=params.num_conv_pos_embeddings, - num_conv_pos_embedding_groups=params.num_conv_pos_embedding_groups, - do_stable_layer_norm=params.do_stable_layer_norm, - apply_spec_augment=params.apply_spec_augment, - mask_time_prob=params.mask_time_prob, - mask_time_length=params.mask_time_length, - mask_time_min_masks=params.mask_time_min_masks, - mask_feature_prob=params.mask_feature_prob, - mask_feature_length=params.mask_feature_length, - mask_feature_min_masks=params.mask_feature_min_masks, - ) - - encoder = HubertModel(config) + if hasattr(params, "pretrained_dir"): + logging.info(f"Loading {params.pretrained_dir}") + encoder = HubertModel.from_pretrained(params.pretrained_dir) + else: + config = HubertConfig( + hidden_size=params.hidden_size, + num_hidden_layers=params.num_hidden_layers, + num_attention_heads=params.num_attention_heads, + intermediate_size=params.intermediate_size, + hidden_act=params.hidden_act, + hidden_dropout=params.hidden_dropout, + activation_dropout=params.activation_dropout, + attention_dropout=params.attention_dropout, + feat_proj_layer_norm=params.feat_proj_layer_norm, + feat_proj_dropout=params.feat_proj_dropout, + final_dropout=params.final_dropout, + layerdrop=params.layerdrop, + initializer_range=params.initializer_range, + layer_norm_eps=params.layer_norm_eps, + feat_extract_norm=params.feat_extract_norm, + feat_extract_activation=params.feat_extract_activation, + conv_dim=_to_int_tuple(params.conv_dim), + conv_stride=_to_int_tuple(params.conv_stride), + conv_kernel=_to_int_tuple(params.conv_kernel), + conv_bias=params.conv_bias, + num_conv_pos_embeddings=params.num_conv_pos_embeddings, + num_conv_pos_embedding_groups=params.num_conv_pos_embedding_groups, + do_stable_layer_norm=params.do_stable_layer_norm, + apply_spec_augment=params.apply_spec_augment, + mask_time_prob=params.mask_time_prob, + mask_time_length=params.mask_time_length, + mask_time_min_masks=params.mask_time_min_masks, + mask_feature_prob=params.mask_feature_prob, + mask_feature_length=params.mask_feature_length, + 
@@ -608,40 +606,43 @@ def _get_feat_extract_output_lengths(
 
 
 def get_encoder_model(params: AttributeDict) -> nn.Module:
-    config = HubertConfig(
-        hidden_size=params.hidden_size,
-        num_hidden_layers=params.num_hidden_layers,
-        num_attention_heads=params.num_attention_heads,
-        intermediate_size=params.intermediate_size,
-        hidden_act=params.hidden_act,
-        hidden_dropout=params.hidden_dropout,
-        activation_dropout=params.activation_dropout,
-        attention_dropout=params.attention_dropout,
-        feat_proj_layer_norm=params.feat_proj_layer_norm,
-        feat_proj_dropout=params.feat_proj_dropout,
-        final_dropout=params.final_dropout,
-        layerdrop=params.layerdrop,
-        initializer_range=params.initializer_range,
-        layer_norm_eps=params.layer_norm_eps,
-        feat_extract_norm=params.feat_extract_norm,
-        feat_extract_activation=params.feat_extract_activation,
-        conv_dim=_to_int_tuple(params.conv_dim),
-        conv_stride=_to_int_tuple(params.conv_stride),
-        conv_kernel=_to_int_tuple(params.conv_kernel),
-        conv_bias=params.conv_bias,
-        num_conv_pos_embeddings=params.num_conv_pos_embeddings,
-        num_conv_pos_embedding_groups=params.num_conv_pos_embedding_groups,
-        do_stable_layer_norm=params.do_stable_layer_norm,
-        apply_spec_augment=params.apply_spec_augment,
-        mask_time_prob=params.mask_time_prob,
-        mask_time_length=params.mask_time_length,
-        mask_time_min_masks=params.mask_time_min_masks,
-        mask_feature_prob=params.mask_feature_prob,
-        mask_feature_length=params.mask_feature_length,
-        mask_feature_min_masks=params.mask_feature_min_masks,
-    )
-
-    encoder = HubertModel(config)
+    if hasattr(params, "pretrained_dir"):
+        logging.info(f"Loading {params.pretrained_dir}")
+        encoder = HubertModel.from_pretrained(params.pretrained_dir)
+    else:
+        config = HubertConfig(
+            hidden_size=params.hidden_size,
+            num_hidden_layers=params.num_hidden_layers,
+            num_attention_heads=params.num_attention_heads,
+            intermediate_size=params.intermediate_size,
+            hidden_act=params.hidden_act,
+            hidden_dropout=params.hidden_dropout,
+            activation_dropout=params.activation_dropout,
+            attention_dropout=params.attention_dropout,
+            feat_proj_layer_norm=params.feat_proj_layer_norm,
+            feat_proj_dropout=params.feat_proj_dropout,
+            final_dropout=params.final_dropout,
+            layerdrop=params.layerdrop,
+            initializer_range=params.initializer_range,
+            layer_norm_eps=params.layer_norm_eps,
+            feat_extract_norm=params.feat_extract_norm,
+            feat_extract_activation=params.feat_extract_activation,
+            conv_dim=_to_int_tuple(params.conv_dim),
+            conv_stride=_to_int_tuple(params.conv_stride),
+            conv_kernel=_to_int_tuple(params.conv_kernel),
+            conv_bias=params.conv_bias,
+            num_conv_pos_embeddings=params.num_conv_pos_embeddings,
+            num_conv_pos_embedding_groups=params.num_conv_pos_embedding_groups,
+            do_stable_layer_norm=params.do_stable_layer_norm,
+            apply_spec_augment=params.apply_spec_augment,
+            mask_time_prob=params.mask_time_prob,
+            mask_time_length=params.mask_time_length,
+            mask_time_min_masks=params.mask_time_min_masks,
+            mask_feature_prob=params.mask_feature_prob,
+            mask_feature_length=params.mask_feature_length,
+            mask_feature_min_masks=params.mask_feature_min_masks,
+        )
+        encoder = HubertModel(config)
 
     return encoder
 
@@ -731,8 +732,6 @@ def load_checkpoint_if_available(
     elif params.start_epoch > 1:
         filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     else:
-        logging.info(f"Loading {params.pretrained_dir}")
-        model.encoder = HubertModel.from_pretrained(params.pretrained_dir)
         return None
 
     assert filename.is_file(), f"{filename} does not exist!"
@@ -839,7 +838,7 @@ def compute_loss(
     """
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     audio = batch["audio"].to(device)
-    audio_lens = batch["audio_lens"].to(device)
+    audio_lens = torch.full(audio.shape[:1], audio.shape[1], dtype=torch.int32)
 
     batch_idx_train = params.batch_idx_train
     warm_step = params.warm_step
@@ -1113,7 +1112,10 @@ def train_one_epoch(
                     "train/grad_scale", cur_grad_scale, params.batch_idx_train
                 )
 
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        if (
+            batch_idx % (params.valid_interval * params.accum_grad) == 0
+            and not params.print_diagnostics
+        ):
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
diff --git a/egs/librispeech/SSL/hubert/model.py b/egs/librispeech/SSL/hubert/model.py
index ce203e3e0..5484e9da5 100644
--- a/egs/librispeech/SSL/hubert/model.py
+++ b/egs/librispeech/SSL/hubert/model.py
@@ -32,7 +32,7 @@ class AsrModel(nn.Module):
         encoder,
         decoder: Optional[nn.Module] = None,
         joiner: Optional[nn.Module] = None,
-        encoder_dim: int = 1024,
+        encoder_dim: int = 768,
         decoder_dim: int = 512,
         vocab_size: int = 500,
         use_transducer: bool = True,
@@ -111,7 +111,7 @@ class AsrModel(nn.Module):
             A 2-D tensor of shape (N, T).
           x_lens:
             A 1-D tensor of shape (N,). It contains the number of frames in `x`
-            before padding.
+            with or without padding.
 
         Returns:
           encoder_out:
@@ -119,12 +119,10 @@
           encoder_out_lens:
             Encoder output lengths, of shape (N,).
         """
+        encoder_out = self.encoder(x).last_hidden_state
         encoder_out_lens = self.encoder._get_feat_extract_output_lengths(x_lens)
         assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
 
-        src_key_padding_mask = make_pad_mask(x_lens)
-        encoder_out = self.encoder(x, src_key_padding_mask).last_hidden_state
-
         return encoder_out, encoder_out_lens
 
     def forward_ctc(
@@ -278,7 +276,7 @@
             A 2-D tensor of shape (N, T).
           x_lens:
             A 1-D tensor of shape (N,). It contains the number of frames in `x`
-            before padding.
+            with or without padding.
           y:
             A ragged tensor with 2 axes [utt][label]. It contains labels of each
             utterance.
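Since `forward_encoder` no longer builds a padding mask, the encoder output lengths in model.py come entirely from `_get_feat_extract_output_lengths`, i.e. from the conv front end's downsampling. A sketch, not part of the patch, of that length arithmetic for HuBERT's default feature encoder (kernel sizes 10,3,3,3,3,2,2 and strides 5,2,2,2,2,2,2, a 320x overall reduction):

import torch

def feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
    # Each conv layer maps a length L to floor((L - kernel) / stride) + 1.
    for kernel, stride in zip((10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)):
        input_lengths = (input_lengths - kernel) // stride + 1
    return input_lengths

print(feat_extract_output_lengths(torch.tensor([16000, 24000])))
# tensor([49, 74]): roughly 49 frames per second of 16 kHz audio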