Mirror of https://github.com/k2-fsa/icefall.git
Synced 2025-09-03 22:24:19 +00:00

Commit 8023493029 ("update")
Parent: a01a4231c4

@@ -25,53 +25,6 @@ from torch.utils.data.dataloader import default_collate
 from transformers import Wav2Vec2FeatureExtractor
 
 
-class HubertDataset(torch.utils.data.Dataset):
-    """
-    In this implementation, there will always be a single channel.
-
-    Returns:
-
-    .. code-block::
-
-        {
-            'audio': (B x NumSamples) float tensor
-            'audio_lens': (B, ) int tensor
-        }
-    """
-
-    def __init__(self, collate: bool = True) -> None:
-        super().__init__()
-        self.feature_extractor = Wav2Vec2FeatureExtractor(
-            feature_size=1,
-            sampling_rate=16000,
-            padding_side="right",
-            padding_value=0.0,
-            do_normalize=True,
-            return_attention_mask=True,
-        )
-
-    def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
-        self._validate(cuts)
-        audio, _ = read_audio_from_cuts(cuts, return_tensors=False)
-        audio = self.feature_extractor(
-            audio,
-            padding=True,
-            return_tensors="pt",
-            sampling_rate=16000,
-        ).input_values
-        audio_lens = torch.tensor([cut.num_samples for cut in cuts], dtype=torch.int32)
-
-        return {
-            "cuts": cuts,
-            "audio": audio,
-            "audio_lens": audio_lens,
-        }
-
-    def _validate(self, cuts: CutSet) -> None:
-        validate(cuts)
-        assert all(cut.has_recording for cut in cuts)
-
-
 class HubertAsrDataset(torch.utils.data.Dataset):
     """
     In this implementation, there will always be a single channel.

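Note: for context, the deleted HubertDataset padded raw 16 kHz waveforms into a
(B x NumSamples) batch via Wav2Vec2FeatureExtractor. A minimal sketch of that
behaviour, with synthetic waveforms standing in for cut audio:

    import torch
    from transformers import Wav2Vec2FeatureExtractor

    feature_extractor = Wav2Vec2FeatureExtractor(
        feature_size=1,
        sampling_rate=16000,
        padding_side="right",
        padding_value=0.0,
        do_normalize=True,
        return_attention_mask=True,
    )

    # Two fake utterances of different lengths (1.0 s and 0.5 s at 16 kHz).
    waves = [torch.randn(16000).numpy(), torch.randn(8000).numpy()]
    batch = feature_extractor(
        waves, padding=True, return_tensors="pt", sampling_rate=16000
    )
    print(batch.input_values.shape)  # torch.Size([2, 16000]), right-padded with 0.0
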
@@ -94,7 +47,8 @@ class HubertAsrDataset(torch.utils.data.Dataset):
             padding_side="right",
             padding_value=0,
             do_normalize=True,
-            return_attention_mask=False,
+            return_attention_mask=True,
+            feature_extractor_type="Wav2Vec2FeatureExtractor",
         )
 
     def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:

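Note: with return_attention_mask=True (the new setting above), the extractor also
returns a (B x NumSamples) 0/1 mask marking real samples vs. right padding, from
which per-utterance lengths can be recovered. Illustrative sketch:

    import torch
    from transformers import Wav2Vec2FeatureExtractor

    fe = Wav2Vec2FeatureExtractor(
        feature_size=1,
        sampling_rate=16000,
        padding_value=0.0,
        do_normalize=True,
        return_attention_mask=True,
    )
    out = fe(
        [torch.randn(16000).numpy(), torch.randn(8000).numpy()],
        padding=True,
        return_tensors="pt",
        sampling_rate=16000,
    )
    print(out.attention_mask.sum(dim=1))  # tensor([16000, 8000])
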
@@ -148,7 +102,4 @@ if __name__ == "__main__":
     )
 
     for batch_idx, batch in enumerate(dl):
-        print(batch["audio"])
-        print(batch["audio_lens"])
-        print(batch["supervisions"]["text"])
-        print(batch["cuts"])
+        break

@@ -121,7 +121,7 @@ from beam_search import (
     modified_beam_search_lm_shallow_fusion,
     modified_beam_search_LODR,
 )
-from train import add_model_arguments, get_model, get_params
+from finetune import add_model_arguments, get_model, get_params
 
 from icefall import ContextGraph, LmScorer, NgramLm
 from icefall.checkpoint import (

@@ -425,16 +425,10 @@ def decode_one_batch(
       the returned dict.
     """
     device = next(model.parameters()).device
-    feature = batch["inputs"]
-    assert feature.ndim == 3
+    audio = batch["audio"].to(device)
+    audio_lens = torch.full(audio.shape[:1], audio.shape[1], dtype=torch.int32)
 
-    feature = feature.to(device)
-    # at entry, feature is (N, T, C)
-
-    supervisions = batch["supervisions"]
-    feature_lens = supervisions["num_frames"].to(device)
-
-    encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+    encoder_out, encoder_out_lens = model.forward_encoder(audio, audio_lens)
 
     hyps = []
 

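Note: the new length handling above fills audio_lens with the padded width, so
every utterance in the batch is treated as if it were full length. A tiny sketch
of what torch.full produces here:

    import torch

    audio = torch.randn(4, 16000)  # (N, NumSamples), already padded
    audio_lens = torch.full(audio.shape[:1], audio.shape[1], dtype=torch.int32)
    print(audio_lens)  # tensor([16000, 16000, 16000, 16000], dtype=torch.int32)
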
@@ -665,7 +659,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
-        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+        cut_ids = [cut.id for cut in batch["cuts"]]
 
         hyps_dict = decode_one_batch(
             params=params,

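Note: the batch dict produced by HubertAsrDataset carries the CutSet under
"cuts", so cut ids now come straight from it rather than from
batch["supervisions"]["cut"]. Sketch with stand-in objects (not lhotse cuts):

    from types import SimpleNamespace

    batch = {"cuts": [SimpleNamespace(id="cut-0"), SimpleNamespace(id="cut-1")]}
    cut_ids = [cut.id for cut in batch["cuts"]]
    print(cut_ids)  # ['cut-0', 'cut-1']
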
@@ -996,14 +990,23 @@ def main():
     args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
 
-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
+    dev_clean_cuts = librispeech.dev_clean_cuts()
+    dev_other_cuts = librispeech.dev_other_cuts()
 
-    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+    dev_clean_dl = librispeech.test_dataloaders(dev_clean_cuts)
+    dev_other_dl = librispeech.test_dataloaders(dev_other_cuts)
 
-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
+    test_sets = ["dev-clean", "dev-other"]
+    test_dl = [dev_clean_dl, dev_other_dl]
+
+    # test_clean_cuts = librispeech.test_clean_cuts()
+    # test_other_cuts = librispeech.test_other_cuts()
+
+    # test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+    # test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
+    # test_sets = ["test-clean", "test-other"]
+    # test_dl = [test_clean_dl, test_other_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(

@@ -31,8 +31,8 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
   --start-epoch 1 \
   --use-fp16 0 \
   --exp-dir hubert/exp \
-  --full-libri 1 \
-  --max-duration 80
+  --full-libri 0 \
+  --max-duration 200
 
 It supports finetuning with:
   - transducer loss (default), with `--use-transducer True --use-ctc False`

@@ -63,7 +63,6 @@ from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
 from optim import Eden, ScaledAdam
-from scaling import ScheduledFloat
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -216,17 +215,17 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--mask-feature-length",
         type=int,
-        default=10,
+        default=64,
     )
     parser.add_argument(
         "--mask-feature-min-masks",
         type=int,
-        default=0,
+        default=2,
     )
     parser.add_argument(
         "--mask-feature-prob",
         type=float,
-        default=0.0,
+        default=0.5,
     )
     parser.add_argument(
         "--mask-time-length",

@@ -236,12 +235,12 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--mask-time-min-masks",
         type=int,
-        default=2,
+        default=10,
     )
     parser.add_argument(
         "--mask-time-prob",
         type=float,
-        default=0.05,
+        default=0.65,
     )
     parser.add_argument(
         "--num-attention-heads",

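Note: the new argparse defaults above configure much more aggressive
SpecAugment-style masking. They ultimately land on HubertConfig; a sketch showing
where (mask-time-length keeps its previous default and is omitted here):

    from transformers import HubertConfig

    config = HubertConfig(
        mask_time_prob=0.65,       # was 0.05
        mask_time_min_masks=10,    # was 2
        mask_feature_prob=0.5,     # was 0.0
        mask_feature_length=64,    # was 10
        mask_feature_min_masks=2,  # was 0
    )
    print(config.mask_time_prob, config.mask_feature_length)  # 0.65 64
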
@@ -361,7 +360,6 @@ def get_parser():
     parser.add_argument(
         "--pretrained-dir",
         type=str,
-        default="download/hubert-base-ls960",
         help="""The pretrained model dir.
         It specifies the directory where the pretrained checkpoint is saved.""",
     )

@@ -374,7 +372,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--base-lr", type=float, default=0.0005, help="The base learning rate."
+        "--base-lr", type=float, default=0.001, help="The base learning rate."
     )
 
     parser.add_argument(

@@ -608,40 +606,43 @@ def _get_feat_extract_output_lengths(
 
 
 def get_encoder_model(params: AttributeDict) -> nn.Module:
-    config = HubertConfig(
-        hidden_size=params.hidden_size,
-        num_hidden_layers=params.num_hidden_layers,
-        num_attention_heads=params.num_attention_heads,
-        intermediate_size=params.intermediate_size,
-        hidden_act=params.hidden_act,
-        hidden_dropout=params.hidden_dropout,
-        activation_dropout=params.activation_dropout,
-        attention_dropout=params.attention_dropout,
-        feat_proj_layer_norm=params.feat_proj_layer_norm,
-        feat_proj_dropout=params.feat_proj_dropout,
-        final_dropout=params.final_dropout,
-        layerdrop=params.layerdrop,
-        initializer_range=params.initializer_range,
-        layer_norm_eps=params.layer_norm_eps,
-        feat_extract_norm=params.feat_extract_norm,
-        feat_extract_activation=params.feat_extract_activation,
-        conv_dim=_to_int_tuple(params.conv_dim),
-        conv_stride=_to_int_tuple(params.conv_stride),
-        conv_kernel=_to_int_tuple(params.conv_kernel),
-        conv_bias=params.conv_bias,
-        num_conv_pos_embeddings=params.num_conv_pos_embeddings,
-        num_conv_pos_embedding_groups=params.num_conv_pos_embedding_groups,
-        do_stable_layer_norm=params.do_stable_layer_norm,
-        apply_spec_augment=params.apply_spec_augment,
-        mask_time_prob=params.mask_time_prob,
-        mask_time_length=params.mask_time_length,
-        mask_time_min_masks=params.mask_time_min_masks,
-        mask_feature_prob=params.mask_feature_prob,
-        mask_feature_length=params.mask_feature_length,
-        mask_feature_min_masks=params.mask_feature_min_masks,
-    )
-
-    encoder = HubertModel(config)
-
+    if hasattr(params, "pretrained_dir"):
+        logging.info(f"Loading {params.pretrained_dir}")
+        encoder = HubertModel.from_pretrained(params.pretrained_dir)
+    else:
+        config = HubertConfig(
+            hidden_size=params.hidden_size,
+            num_hidden_layers=params.num_hidden_layers,
+            num_attention_heads=params.num_attention_heads,
+            intermediate_size=params.intermediate_size,
+            hidden_act=params.hidden_act,
+            hidden_dropout=params.hidden_dropout,
+            activation_dropout=params.activation_dropout,
+            attention_dropout=params.attention_dropout,
+            feat_proj_layer_norm=params.feat_proj_layer_norm,
+            feat_proj_dropout=params.feat_proj_dropout,
+            final_dropout=params.final_dropout,
+            layerdrop=params.layerdrop,
+            initializer_range=params.initializer_range,
+            layer_norm_eps=params.layer_norm_eps,
+            feat_extract_norm=params.feat_extract_norm,
+            feat_extract_activation=params.feat_extract_activation,
+            conv_dim=_to_int_tuple(params.conv_dim),
+            conv_stride=_to_int_tuple(params.conv_stride),
+            conv_kernel=_to_int_tuple(params.conv_kernel),
+            conv_bias=params.conv_bias,
+            num_conv_pos_embeddings=params.num_conv_pos_embeddings,
+            num_conv_pos_embedding_groups=params.num_conv_pos_embedding_groups,
+            do_stable_layer_norm=params.do_stable_layer_norm,
+            apply_spec_augment=params.apply_spec_augment,
+            mask_time_prob=params.mask_time_prob,
+            mask_time_length=params.mask_time_length,
+            mask_time_min_masks=params.mask_time_min_masks,
+            mask_feature_prob=params.mask_feature_prob,
+            mask_feature_length=params.mask_feature_length,
+            mask_feature_min_masks=params.mask_feature_min_masks,
+        )
+        encoder = HubertModel(config)
     return encoder

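Note: a condensed sketch of the new loading logic: prefer a local pretrained
checkpoint when params carries pretrained_dir, otherwise build the encoder from
a config. params is mimicked with SimpleNamespace; per-param config kwargs are
elided:

    from types import SimpleNamespace

    from transformers import HubertConfig, HubertModel

    def get_encoder(params):
        if hasattr(params, "pretrained_dir"):
            return HubertModel.from_pretrained(params.pretrained_dir)
        return HubertModel(HubertConfig())  # config kwargs elided in this sketch

    encoder = get_encoder(SimpleNamespace())  # no pretrained_dir -> fresh model

Combined with the removal in load_checkpoint_if_available below, pretrained
weights are now loaded once at model construction rather than during checkpoint
resolution.
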
@@ -731,8 +732,6 @@ def load_checkpoint_if_available(
     elif params.start_epoch > 1:
         filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     else:
-        logging.info(f"Loading {params.pretrained_dir}")
-        model.encoder = HubertModel.from_pretrained(params.pretrained_dir)
         return None
 
     assert filename.is_file(), f"{filename} does not exist!"

@@ -839,7 +838,7 @@ def compute_loss(
     """
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     audio = batch["audio"].to(device)
-    audio_lens = batch["audio_lens"].to(device)
+    audio_lens = torch.full(audio.shape[:1], audio.shape[1], dtype=torch.int32)
 
     batch_idx_train = params.batch_idx_train
     warm_step = params.warm_step

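Note: as in decode_one_batch, compute_loss now treats every row as full length.
If true per-utterance lengths were needed, they could instead be recovered from
the extractor's attention mask; an illustrative alternative, not what this
commit does:

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
    audio_lens = attention_mask.sum(dim=1).to(torch.int32)
    print(audio_lens)  # tensor([4, 2], dtype=torch.int32)
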
@@ -1113,7 +1112,10 @@ def train_one_epoch(
                     "train/grad_scale", cur_grad_scale, params.batch_idx_train
                 )
 
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        if (
+            batch_idx % (params.valid_interval * params.accum_grad) == 0
+            and not params.print_diagnostics
+        ):
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,

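Note: with gradient accumulation, batch_idx advances accum_grad times per
optimizer step, so scaling the interval by accum_grad keeps validation frequency
fixed in optimizer steps. Sketch of the counting:

    valid_interval = 100
    accum_grad = 4

    for batch_idx in range(1, 1201):
        if batch_idx % (valid_interval * accum_grad) == 0:
            print(f"validate at batch {batch_idx}")  # batches 400, 800, 1200
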
@@ -32,7 +32,7 @@ class AsrModel(nn.Module):
         encoder,
         decoder: Optional[nn.Module] = None,
         joiner: Optional[nn.Module] = None,
-        encoder_dim: int = 1024,
+        encoder_dim: int = 768,
         decoder_dim: int = 512,
         vocab_size: int = 500,
         use_transducer: bool = True,

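Note: 768 is the hidden size of HuBERT base (large uses 1024), so the new
encoder_dim default matches hubert-base checkpoints. Quick check against the
transformers default config:

    from transformers import HubertConfig

    print(HubertConfig().hidden_size)  # 768
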
@@ -111,7 +111,7 @@ class AsrModel(nn.Module):
             A 2-D tensor of shape (N, T).
           x_lens:
             A 1-D tensor of shape (N,). It contains the number of frames in `x`
-            before padding.
+            w/wo padding.
 
         Returns:
           encoder_out:

@@ -119,12 +119,10 @@ class AsrModel(nn.Module):
           encoder_out_lens:
             Encoder output lengths, of shape (N,).
         """
+        encoder_out = self.encoder(x).last_hidden_state
         encoder_out_lens = self.encoder._get_feat_extract_output_lengths(x_lens)
         assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
 
-        src_key_padding_mask = make_pad_mask(x_lens)
-        encoder_out = self.encoder(x, src_key_padding_mask).last_hidden_state
-
         return encoder_out, encoder_out_lens
 
     def forward_ctc(

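Note: in the new forward_encoder, HubertModel consumes raw waveforms without a
padding mask, and output lengths come from the conv front-end's downsampling
formula. A self-contained sketch with a randomly initialized base model:

    import torch
    from transformers import HubertConfig, HubertModel

    model = HubertModel(HubertConfig())
    x = torch.randn(2, 16000)                 # (N, T) raw audio
    x_lens = torch.tensor([16000, 16000])
    encoder_out = model(x).last_hidden_state  # (N, T', 768)
    encoder_out_lens = model._get_feat_extract_output_lengths(x_lens)
    print(encoder_out.shape, encoder_out_lens)  # roughly 49 frames per second
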
@@ -278,7 +276,7 @@ class AsrModel(nn.Module):
             A 2-D tensor of shape (N, T).
           x_lens:
             A 1-D tensor of shape (N,). It contains the number of frames in `x`
-            before padding.
+            w/wo padding.
           y:
             A ragged tensor with 2 axes [utt][label]. It contains labels of each
             utterance.