Yifan Yang 2024-01-01 20:09:22 +08:00 committed by GitHub
parent a01a4231c4
commit 8023493029
4 changed files with 77 additions and 123 deletions

View File

@@ -25,53 +25,6 @@ from torch.utils.data.dataloader import default_collate
 from transformers import Wav2Vec2FeatureExtractor
 
-class HubertDataset(torch.utils.data.Dataset):
-    """
-    In this implementation, there will always be a single channel.
-
-    Returns:
-
-    .. code-block::
-
-        {
-            'audio': (B x NumSamples) float tensor
-            'audio_lens': (B, ) int tensor
-        }
-    """
-
-    def __init__(self, collate: bool = True) -> None:
-        super().__init__()
-        self.feature_extractor = Wav2Vec2FeatureExtractor(
-            feature_size=1,
-            sampling_rate=16000,
-            padding_side="right",
-            padding_value=0.0,
-            do_normalize=True,
-            return_attention_mask=True,
-        )
-
-    def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
-        self._validate(cuts)
-        audio, _ = read_audio_from_cuts(cuts, return_tensors=False)
-        audio = self.feature_extractor(
-            audio,
-            padding=True,
-            return_tensors="pt",
-            sampling_rate=16000,
-        ).input_values
-        audio_lens = torch.tensor([cut.num_samples for cut in cuts], dtype=torch.int32)
-        return {
-            "cuts": cuts,
-            "audio": audio,
-            "audio_lens": audio_lens,
-        }
-
-    def _validate(self, cuts: CutSet) -> None:
-        validate(cuts)
-        assert all(cut.has_recording for cut in cuts)
-
 class HubertAsrDataset(torch.utils.data.Dataset):
     """
     In this implementation, there will always be a single channel.
@@ -94,7 +47,8 @@ class HubertAsrDataset(torch.utils.data.Dataset):
             padding_side="right",
             padding_value=0,
             do_normalize=True,
-            return_attention_mask=False,
+            return_attention_mask=True,
+            feature_extractor_type="Wav2Vec2FeatureExtractor",
         )
 
     def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
@@ -148,7 +102,4 @@ if __name__ == "__main__":
     )
 
     for batch_idx, batch in enumerate(dl):
-        print(batch["audio"])
-        print(batch["audio_lens"])
-        print(batch["supervisions"]["text"])
-        print(batch["cuts"])
+        break

View File

@@ -121,7 +121,7 @@ from beam_search import (
     modified_beam_search_lm_shallow_fusion,
     modified_beam_search_LODR,
 )
-from train import add_model_arguments, get_model, get_params
+from finetune import add_model_arguments, get_model, get_params
 from icefall import ContextGraph, LmScorer, NgramLm
 from icefall.checkpoint import (
@@ -425,16 +425,10 @@ def decode_one_batch(
       the returned dict.
     """
     device = next(model.parameters()).device
-    feature = batch["inputs"]
-    assert feature.ndim == 3
-
-    feature = feature.to(device)
-    # at entry, feature is (N, T, C)
-
-    supervisions = batch["supervisions"]
-    feature_lens = supervisions["num_frames"].to(device)
-
-    encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+    audio = batch["audio"].to(device)
+    audio_lens = torch.full(audio.shape[:1], audio.shape[1], dtype=torch.int32)
+
+    encoder_out, encoder_out_lens = model.forward_encoder(audio, audio_lens)
 
     hyps = []
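
The torch.full call above relies on every row of batch["audio"] having the same padded length, so the lengths tensor is simply filled with that common value. A small sketch of the convention, with illustrative shapes:

    import torch

    audio = torch.randn(4, 16000)  # (N, T): the collated batch is already padded
    audio_lens = torch.full(audio.shape[:1], audio.shape[1], dtype=torch.int32)
    print(audio_lens)  # tensor([16000, 16000, 16000, 16000], dtype=torch.int32)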
@@ -665,7 +659,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
-        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+        cut_ids = [cut.id for cut in batch["cuts"]]
 
         hyps_dict = decode_one_batch(
             params=params,
@@ -996,14 +990,23 @@ def main():
     args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
 
-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
+    dev_clean_cuts = librispeech.dev_clean_cuts()
+    dev_other_cuts = librispeech.dev_other_cuts()
 
-    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+    dev_clean_dl = librispeech.test_dataloaders(dev_clean_cuts)
+    dev_other_dl = librispeech.test_dataloaders(dev_other_cuts)
 
-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
+    test_sets = ["dev-clean", "dev-other"]
+    test_dl = [dev_clean_dl, dev_other_dl]
+
+    # test_clean_cuts = librispeech.test_clean_cuts()
+    # test_other_cuts = librispeech.test_other_cuts()
+    # test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+    # test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+    # test_sets = ["test-clean", "test-other"]
+    # test_dl = [test_clean_dl, test_other_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(

View File

@@ -31,8 +31,8 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
   --start-epoch 1 \
   --use-fp16 0 \
   --exp-dir hubert/exp \
-  --full-libri 1 \
-  --max-duration 80
+  --full-libri 0 \
+  --max-duration 200
 
 It supports finetuning with:
   - transducer loss (default), with `--use-transducer True --use-ctc False`
@@ -63,7 +63,6 @@ from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
 from optim import Eden, ScaledAdam
-from scaling import ScheduledFloat
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -216,17 +215,17 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--mask-feature-length",
         type=int,
-        default=10,
+        default=64,
     )
 
     parser.add_argument(
         "--mask-feature-min-masks",
         type=int,
-        default=0,
+        default=2,
     )
 
     parser.add_argument(
         "--mask-feature-prob",
         type=float,
-        default=0.0,
+        default=0.5,
     )
 
     parser.add_argument(
         "--mask-time-length",
@@ -236,12 +235,12 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--mask-time-min-masks",
         type=int,
-        default=2,
+        default=10,
     )
 
     parser.add_argument(
         "--mask-time-prob",
         type=float,
-        default=0.05,
+        default=0.65,
     )
 
     parser.add_argument(
         "--num-attention-heads",
@@ -361,7 +360,6 @@ def get_parser():
     parser.add_argument(
         "--pretrained-dir",
         type=str,
-        default="download/hubert-base-ls960",
         help="""The pretrained model dir.
         It specifies the directory where the pretrained checkpoint is saved.""",
     )
@@ -374,7 +372,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--base-lr", type=float, default=0.0005, help="The base learning rate."
+        "--base-lr", type=float, default=0.001, help="The base learning rate."
     )
 
     parser.add_argument(
parser.add_argument( parser.add_argument(
@@ -608,40 +606,43 @@ def _get_feat_extract_output_lengths(
 def get_encoder_model(params: AttributeDict) -> nn.Module:
-    config = HubertConfig(
-        hidden_size=params.hidden_size,
-        num_hidden_layers=params.num_hidden_layers,
-        num_attention_heads=params.num_attention_heads,
-        intermediate_size=params.intermediate_size,
-        hidden_act=params.hidden_act,
-        hidden_dropout=params.hidden_dropout,
-        activation_dropout=params.activation_dropout,
-        attention_dropout=params.attention_dropout,
-        feat_proj_layer_norm=params.feat_proj_layer_norm,
-        feat_proj_dropout=params.feat_proj_dropout,
-        final_dropout=params.final_dropout,
-        layerdrop=params.layerdrop,
-        initializer_range=params.initializer_range,
-        layer_norm_eps=params.layer_norm_eps,
-        feat_extract_norm=params.feat_extract_norm,
-        feat_extract_activation=params.feat_extract_activation,
-        conv_dim=_to_int_tuple(params.conv_dim),
-        conv_stride=_to_int_tuple(params.conv_stride),
-        conv_kernel=_to_int_tuple(params.conv_kernel),
-        conv_bias=params.conv_bias,
-        num_conv_pos_embeddings=params.num_conv_pos_embeddings,
-        num_conv_pos_embedding_groups=params.num_conv_pos_embedding_groups,
-        do_stable_layer_norm=params.do_stable_layer_norm,
-        apply_spec_augment=params.apply_spec_augment,
-        mask_time_prob=params.mask_time_prob,
-        mask_time_length=params.mask_time_length,
-        mask_time_min_masks=params.mask_time_min_masks,
-        mask_feature_prob=params.mask_feature_prob,
-        mask_feature_length=params.mask_feature_length,
-        mask_feature_min_masks=params.mask_feature_min_masks,
-    )
-
-    encoder = HubertModel(config)
+    if hasattr(params, "pretrained_dir"):
+        logging.info(f"Loading {params.pretrained_dir}")
+        encoder = HubertModel.from_pretrained(params.pretrained_dir)
+    else:
+        config = HubertConfig(
+            hidden_size=params.hidden_size,
+            num_hidden_layers=params.num_hidden_layers,
+            num_attention_heads=params.num_attention_heads,
+            intermediate_size=params.intermediate_size,
+            hidden_act=params.hidden_act,
+            hidden_dropout=params.hidden_dropout,
+            activation_dropout=params.activation_dropout,
+            attention_dropout=params.attention_dropout,
+            feat_proj_layer_norm=params.feat_proj_layer_norm,
+            feat_proj_dropout=params.feat_proj_dropout,
+            final_dropout=params.final_dropout,
+            layerdrop=params.layerdrop,
+            initializer_range=params.initializer_range,
+            layer_norm_eps=params.layer_norm_eps,
+            feat_extract_norm=params.feat_extract_norm,
+            feat_extract_activation=params.feat_extract_activation,
+            conv_dim=_to_int_tuple(params.conv_dim),
+            conv_stride=_to_int_tuple(params.conv_stride),
+            conv_kernel=_to_int_tuple(params.conv_kernel),
+            conv_bias=params.conv_bias,
+            num_conv_pos_embeddings=params.num_conv_pos_embeddings,
+            num_conv_pos_embedding_groups=params.num_conv_pos_embedding_groups,
+            do_stable_layer_norm=params.do_stable_layer_norm,
+            apply_spec_augment=params.apply_spec_augment,
+            mask_time_prob=params.mask_time_prob,
+            mask_time_length=params.mask_time_length,
+            mask_time_min_masks=params.mask_time_min_masks,
+            mask_feature_prob=params.mask_feature_prob,
+            mask_feature_length=params.mask_feature_length,
+            mask_feature_min_masks=params.mask_feature_min_masks,
+        )
+        encoder = HubertModel(config)
 
     return encoder
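
A minimal sketch of the two construction paths above; the checkpoint directory is the recipe's former default and is assumed here to hold a local Hugging Face HuBERT checkpoint:

    from transformers import HubertConfig, HubertModel

    # Path 1: warm-start from a local pretrained checkpoint (path is an assumption).
    encoder = HubertModel.from_pretrained("download/hubert-base-ls960")

    # Path 2: randomly initialized model from an explicit config.
    config = HubertConfig(hidden_size=768, num_hidden_layers=12)
    encoder_scratch = HubertModel(config)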
@@ -731,8 +732,6 @@ def load_checkpoint_if_available(
     elif params.start_epoch > 1:
         filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     else:
-        logging.info(f"Loading {params.pretrained_dir}")
-        model.encoder = HubertModel.from_pretrained(params.pretrained_dir)
         return None
 
     assert filename.is_file(), f"{filename} does not exist!"
@@ -839,7 +838,7 @@ def compute_loss(
     """
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
 
     audio = batch["audio"].to(device)
-    audio_lens = batch["audio_lens"].to(device)
+    audio_lens = torch.full(audio.shape[:1], audio.shape[1], dtype=torch.int32)
 
     batch_idx_train = params.batch_idx_train
     warm_step = params.warm_step
@@ -1113,7 +1112,10 @@ def train_one_epoch(
                 "train/grad_scale", cur_grad_scale, params.batch_idx_train
             )
 
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        if (
+            batch_idx % (params.valid_interval * params.accum_grad) == 0
+            and not params.print_diagnostics
+        ):
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
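
A small sketch of the new validation cadence: with gradient accumulation, one optimizer step spans accum_grad micro-batches, so scaling the interval keeps validation aligned with optimizer steps (the values below are assumptions for illustration):

    valid_interval, accum_grad = 3000, 4  # assumed values

    for batch_idx in range(1, 24001):
        if batch_idx % (valid_interval * accum_grad) == 0:
            print(f"validate at micro-batch {batch_idx}")  # fires at 12000 and 24000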

View File

@@ -32,7 +32,7 @@ class AsrModel(nn.Module):
         encoder,
         decoder: Optional[nn.Module] = None,
         joiner: Optional[nn.Module] = None,
-        encoder_dim: int = 1024,
+        encoder_dim: int = 768,
         decoder_dim: int = 512,
         vocab_size: int = 500,
         use_transducer: bool = True,
@@ -111,7 +111,7 @@ class AsrModel(nn.Module):
             A 2-D tensor of shape (N, T).
           x_lens:
             A 1-D tensor of shape (N,). It contains the number of frames in `x`
-            before padding.
+            w/wo padding.
 
         Returns:
           encoder_out:
@@ -119,12 +119,10 @@ class AsrModel(nn.Module):
           encoder_out_lens:
             Encoder output lengths, of shape (N,).
         """
+        encoder_out = self.encoder(x).last_hidden_state
         encoder_out_lens = self.encoder._get_feat_extract_output_lengths(x_lens)
 
         assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
 
-        src_key_padding_mask = make_pad_mask(x_lens)
-        encoder_out = self.encoder(x, src_key_padding_mask).last_hidden_state
-
         return encoder_out, encoder_out_lens
 
     def forward_ctc(
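
For intuition, a sketch of what _get_feat_extract_output_lengths computes: each conv layer in the HuBERT feature extractor maps a length L to floor((L - kernel) / stride) + 1. The kernel/stride values below are the standard base-model defaults, stated here as assumptions:

    import torch

    def conv_out_lengths(lens: torch.Tensor) -> torch.Tensor:
        # Assumed defaults: conv_kernel=(10,3,3,3,3,2,2), conv_stride=(5,2,2,2,2,2,2)
        for kernel, stride in zip((10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)):
            lens = (lens - kernel) // stride + 1
        return lens

    print(conv_out_lengths(torch.tensor([16000])))  # tensor([49]): ~49 frames per second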
@@ -278,7 +276,7 @@ class AsrModel(nn.Module):
             A 2-D tensor of shape (N, T).
           x_lens:
             A 1-D tensor of shape (N,). It contains the number of frames in `x`
-            before padding.
+            w/wo padding.
           y:
             A ragged tensor with 2 axes [utt][label]. It contains labels of each
             utterance.