mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-04 06:34:20 +00:00

commit 8023493029 ("update"), parent a01a4231c4
@@ -25,53 +25,6 @@ from torch.utils.data.dataloader import default_collate
 from transformers import Wav2Vec2FeatureExtractor
 
 
-class HubertDataset(torch.utils.data.Dataset):
-    """
-    In this implementation, there will always be a single channel.
-
-    Returns:
-
-    .. code-block::
-
-        {
-            'audio': (B x NumSamples) float tensor
-            'audio_lens': (B, ) int tensor
-        }
-    """
-
-    def __init__(self, collate: bool = True) -> None:
-        super().__init__()
-        self.feature_extractor = Wav2Vec2FeatureExtractor(
-            feature_size=1,
-            sampling_rate=16000,
-            padding_side="right",
-            padding_value=0.0,
-            do_normalize=True,
-            return_attention_mask=True,
-        )
-
-    def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
-        self._validate(cuts)
-        audio, _ = read_audio_from_cuts(cuts, return_tensors=False)
-        audio = self.feature_extractor(
-            audio,
-            padding=True,
-            return_tensors="pt",
-            sampling_rate=16000,
-        ).input_values
-        audio_lens = torch.tensor([cut.num_samples for cut in cuts], dtype=torch.int32)
-
-        return {
-            "cuts": cuts,
-            "audio": audio,
-            "audio_lens": audio_lens,
-        }
-
-    def _validate(self, cuts: CutSet) -> None:
-        validate(cuts)
-        assert all(cut.has_recording for cut in cuts)
-
-
 class HubertAsrDataset(torch.utils.data.Dataset):
     """
     In this implementation, there will always be a single channel.
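Both the removed HubertDataset and the surviving HubertAsrDataset lean on HuggingFace's Wav2Vec2FeatureExtractor to batch raw waveforms: right-padding each one to the longest sample in the batch and normalizing to zero mean and unit variance. A minimal sketch of that behavior (the waveform lengths are arbitrary):

```python
import numpy as np
from transformers import Wav2Vec2FeatureExtractor

extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)

# Two raw waveforms of unequal length (values and lengths are arbitrary).
batch = [np.random.randn(16000), np.random.randn(12000)]
out = extractor(batch, padding=True, return_tensors="pt", sampling_rate=16000)

print(out.input_values.shape)     # torch.Size([2, 16000]); padded to the longest
print(out.attention_mask.sum(1))  # tensor([16000, 12000]); true length per row
```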
@@ -94,7 +47,8 @@ class HubertAsrDataset(torch.utils.data.Dataset):
             padding_side="right",
             padding_value=0,
             do_normalize=True,
-            return_attention_mask=False,
+            return_attention_mask=True,
+            feature_extractor_type="Wav2Vec2FeatureExtractor",
         )
 
     def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
@@ -148,7 +102,4 @@ if __name__ == "__main__":
     )
     for batch_idx, batch in enumerate(dl):
-        print(batch["audio"])
-        print(batch["audio_lens"])
         print(batch["supervisions"]["text"])
-        print(batch["cuts"])
         break
@@ -121,7 +121,7 @@ from beam_search import (
     modified_beam_search_lm_shallow_fusion,
     modified_beam_search_LODR,
 )
-from train import add_model_arguments, get_model, get_params
+from finetune import add_model_arguments, get_model, get_params
 
 from icefall import ContextGraph, LmScorer, NgramLm
 from icefall.checkpoint import (
@@ -425,16 +425,10 @@ def decode_one_batch(
       the returned dict.
     """
     device = next(model.parameters()).device
-    feature = batch["inputs"]
-    assert feature.ndim == 3
-
-    feature = feature.to(device)
-    # at entry, feature is (N, T, C)
-
-    supervisions = batch["supervisions"]
-    feature_lens = supervisions["num_frames"].to(device)
-
-    encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+    audio = batch["audio"].to(device)
+    audio_lens = torch.full(audio.shape[:1], audio.shape[1], dtype=torch.int32)
+
+    encoder_out, encoder_out_lens = model.forward_encoder(audio, audio_lens)
 
     hyps = []
 
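A note on the new `audio_lens`: `torch.full` assigns every utterance the padded batch width rather than its true sample count, so padding is treated as real audio downstream. A small sketch of what the expression evaluates to (the shapes are arbitrary):

```python
import torch

# Assume a padded batch of 4 utterances, 52400 samples wide.
audio = torch.zeros(4, 52400)

# audio.shape[:1] is (4,) and audio.shape[1] is 52400, so every entry gets
# the full padded width: tensor([52400, 52400, 52400, 52400], dtype=torch.int32).
audio_lens = torch.full(audio.shape[:1], audio.shape[1], dtype=torch.int32)
print(audio_lens)
```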
@@ -665,7 +659,7 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
-        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+        cut_ids = [cut.id for cut in batch["cuts"]]
 
         hyps_dict = decode_one_batch(
             params=params,
@@ -996,14 +990,23 @@ def main():
     args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
 
-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
+    dev_clean_cuts = librispeech.dev_clean_cuts()
+    dev_other_cuts = librispeech.dev_other_cuts()
 
-    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+    dev_clean_dl = librispeech.test_dataloaders(dev_clean_cuts)
+    dev_other_dl = librispeech.test_dataloaders(dev_other_cuts)
 
-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
+    test_sets = ["dev-clean", "dev-other"]
+    test_dl = [dev_clean_dl, dev_other_dl]
+
+    # test_clean_cuts = librispeech.test_clean_cuts()
+    # test_other_cuts = librispeech.test_other_cuts()
+
+    # test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+    # test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
+    # test_sets = ["test-clean", "test-other"]
+    # test_dl = [test_clean_dl, test_other_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
@@ -31,8 +31,8 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
   --start-epoch 1 \
   --use-fp16 0 \
   --exp-dir hubert/exp \
-  --full-libri 1 \
-  --max-duration 80
+  --full-libri 0 \
+  --max-duration 200
 
 It supports finetuning with:
 - transducer loss (default), with `--use-transducer True --use-ctc False`
@@ -63,7 +63,6 @@ from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
 from optim import Eden, ScaledAdam
-from scaling import ScheduledFloat
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -216,17 +215,17 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--mask-feature-length",
         type=int,
-        default=10,
+        default=64,
     )
     parser.add_argument(
         "--mask-feature-min-masks",
         type=int,
-        default=0,
+        default=2,
     )
     parser.add_argument(
         "--mask-feature-prob",
         type=float,
-        default=0.0,
+        default=0.5,
     )
     parser.add_argument(
         "--mask-time-length",
@@ -236,12 +235,12 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--mask-time-min-masks",
         type=int,
-        default=2,
+        default=10,
     )
     parser.add_argument(
         "--mask-time-prob",
         type=float,
-        default=0.05,
+        default=0.65,
     )
     parser.add_argument(
         "--num-attention-heads",
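Taken together, the two hunks above raise the SpecAugment-style masking defaults used during finetuning. A sketch of how these flags map onto the same-named fields of HuggingFace's HubertConfig; `mask_time_length` is assumed to stay at the library default of 10, since its value is not visible in the hunk:

```python
from transformers import HubertConfig, HubertModel

# Masking config mirroring the updated argparse defaults above.
config = HubertConfig(
    apply_spec_augment=True,
    mask_time_prob=0.65,       # --mask-time-prob
    mask_time_length=10,       # --mask-time-length (assumed, not shown in the hunk)
    mask_time_min_masks=10,    # --mask-time-min-masks
    mask_feature_prob=0.5,     # --mask-feature-prob
    mask_feature_length=64,    # --mask-feature-length
    mask_feature_min_masks=2,  # --mask-feature-min-masks
)
encoder = HubertModel(config)
```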
@@ -361,7 +360,6 @@ def get_parser():
     parser.add_argument(
         "--pretrained-dir",
         type=str,
-        default="download/hubert-base-ls960",
         help="""The pretrained model dir.
         It specifies the directory where the pretrained checkpoint is saved.""",
     )
@@ -374,7 +372,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--base-lr", type=float, default=0.0005, help="The base learning rate."
+        "--base-lr", type=float, default=0.001, help="The base learning rate."
     )
 
     parser.add_argument(
@@ -608,6 +606,10 @@ def _get_feat_extract_output_lengths(
 
 
 def get_encoder_model(params: AttributeDict) -> nn.Module:
+    if hasattr(params, "pretrained_dir"):
+        logging.info(f"Loading {params.pretrained_dir}")
+        encoder = HubertModel.from_pretrained(params.pretrained_dir)
+    else:
         config = HubertConfig(
             hidden_size=params.hidden_size,
             num_hidden_layers=params.num_hidden_layers,
@@ -640,7 +642,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
             mask_feature_length=params.mask_feature_length,
             mask_feature_min_masks=params.mask_feature_min_masks,
         )
-
         encoder = HubertModel(config)
     return encoder
 
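With this refactor, pretrained weights are loaded when the encoder is constructed rather than in load_checkpoint_if_available (next hunk). A minimal sketch of the two branches; the directory string is the old --pretrained-dir default quoted above:

```python
from transformers import HubertConfig, HubertModel

# Branch 1: params has pretrained_dir; config and weights both come from
# the local HuggingFace checkpoint directory.
encoder = HubertModel.from_pretrained("download/hubert-base-ls960")

# Branch 2: no pretrained_dir; a randomly initialized encoder is built from
# a HubertConfig filled in from the command-line arguments.
encoder = HubertModel(HubertConfig())
```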
@@ -731,8 +732,6 @@ def load_checkpoint_if_available(
     elif params.start_epoch > 1:
         filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     else:
-        logging.info(f"Loading {params.pretrained_dir}")
-        model.encoder = HubertModel.from_pretrained(params.pretrained_dir)
         return None
 
     assert filename.is_file(), f"{filename} does not exist!"
@@ -839,7 +838,7 @@ def compute_loss(
     """
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     audio = batch["audio"].to(device)
-    audio_lens = batch["audio_lens"].to(device)
+    audio_lens = torch.full(audio.shape[:1], audio.shape[1], dtype=torch.int32)
 
     batch_idx_train = params.batch_idx_train
     warm_step = params.warm_step
@@ -1113,7 +1112,10 @@ def train_one_epoch(
                     "train/grad_scale", cur_grad_scale, params.batch_idx_train
                 )
 
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        if (
+            batch_idx % (params.valid_interval * params.accum_grad) == 0
+            and not params.print_diagnostics
+        ):
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
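Why the interval is rescaled: `batch_idx` counts micro-batches, and with gradient accumulation one optimizer step consumes `accum_grad` of them. A minimal illustration with assumed values (neither default is shown in this diff):

```python
# Assumed values, for illustration only.
valid_interval = 20
accum_grad = 4

for batch_idx in range(1, 161):
    # The old condition fired every 20 micro-batches (5 optimizer steps);
    # the new one fires every 20 * 4 = 80 micro-batches, keeping the
    # validation cadence fixed at 20 optimizer steps.
    if batch_idx % (valid_interval * accum_grad) == 0:
        print(f"validate at micro-batch {batch_idx}")  # fires at 80 and 160
```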
@@ -32,7 +32,7 @@ class AsrModel(nn.Module):
         encoder,
         decoder: Optional[nn.Module] = None,
         joiner: Optional[nn.Module] = None,
-        encoder_dim: int = 1024,
+        encoder_dim: int = 768,
         decoder_dim: int = 512,
         vocab_size: int = 500,
         use_transducer: bool = True,
@@ -111,7 +111,7 @@ class AsrModel(nn.Module):
             A 2-D tensor of shape (N, T).
           x_lens:
             A 1-D tensor of shape (N,). It contains the number of frames in `x`
-            before padding.
+            w/wo padding.
 
         Returns:
           encoder_out:
@@ -119,12 +119,10 @@ class AsrModel(nn.Module):
           encoder_out_lens:
             Encoder output lengths, of shape (N,).
         """
+        encoder_out = self.encoder(x).last_hidden_state
         encoder_out_lens = self.encoder._get_feat_extract_output_lengths(x_lens)
-        assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
-
-        src_key_padding_mask = make_pad_mask(x_lens)
-        encoder_out = self.encoder(x, src_key_padding_mask).last_hidden_state
 
         return encoder_out, encoder_out_lens
 
     def forward_ctc(
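forward_encoder now derives output lengths from the encoder's convolutional waveform front end via `_get_feat_extract_output_lengths` instead of building a padding mask with `make_pad_mask`. A sketch of that length arithmetic; the (kernel, stride) pairs below are the hubert-base defaults and are an assumption, not read from this diff:

```python
import torch

def feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
    # (kernel, stride) of each conv layer in the HuBERT feature extractor
    # (hubert-base defaults, assumed here).
    conv_layers = [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]
    for kernel, stride in conv_layers:
        # Standard output-length formula for an unpadded 1-D convolution.
        input_lengths = (input_lengths - kernel) // stride + 1
    return input_lengths

# One second of 16 kHz audio yields 49 frames, i.e. a ~50 Hz frame rate.
print(feat_extract_output_lengths(torch.tensor([16000])))  # tensor([49])
```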
@@ -278,7 +276,7 @@ class AsrModel(nn.Module):
             A 2-D tensor of shape (N, T).
           x_lens:
             A 1-D tensor of shape (N,). It contains the number of frames in `x`
-            before padding.
+            w/wo padding.
           y:
             A ragged tensor with 2 axes [utt][label]. It contains labels of each
             utterance.