mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
decoder hubert model
This commit is contained in:
parent
9d48f1ce7d
commit
fb9c0c3971
@ -34,7 +34,7 @@ from lhotse.dataset import (
|
||||
SingleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from lhotse.dataset.input_strategies import AudioSamples, OnTheFlyFeatures
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
@ -192,6 +192,13 @@ class LibriSpeechAsrDataModule:
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--input-strategy",
|
||||
type=str,
|
||||
default="PrecomputedFeatures",
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
@ -263,6 +270,9 @@ class LibriSpeechAsrDataModule:
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2SpeechRecognitionDataset(
|
||||
input_strategy=AudioSamples()
|
||||
if self.args.input_strategy == "AudioSamples"
|
||||
else PrecomputedFeatures(),
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
@ -371,7 +381,7 @@ class LibriSpeechAsrDataModule:
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
else eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = BucketingSampler(
|
||||
|
211
egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_decode.py
Executable file
211
egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_decode.py
Executable file
@ -0,0 +1,211 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Liyong Guo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from fairseq.data.data_utils import post_process
|
||||
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from hubert_utils import (
|
||||
extract_layers_result,
|
||||
load_hubert_model,
|
||||
get_parser,
|
||||
vq_config,
|
||||
)
|
||||
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    model: nn.Module,
    processor,
    params,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode a dataset with greedy (best-path) CTC search.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode. Each batch
        must provide ``batch["supervisions"]["text"]`` in addition to
        whatever ``extract_layers_result`` reads from it.
      model:
        The neural model. It must expose ``w2v_encoder.w2v_model`` and
        ``w2v_encoder.proj``.
      processor:
        Object whose ``string()`` method converts a tensor of token ids
        into text.
      params:
        Settings object; ``params.device`` and ``params.total_layers``
        are read.

    Returns:
      Return a dict with the single key "ctc_greedy_search" (the dict is
      empty when ``dl`` yields no batch).
      Its value is a list of tuples. Each tuple contains two elements:
      The first is the reference transcript, and the second is the
      predicted result, both as lists of words.
    """
    try:
        num_batches = len(dl)
    except TypeError:
        # Not every dataloader defines __len__; only used for logging.
        num_batches = "?"

    # Loop invariants.
    w2v_model = model.w2v_encoder.w2v_model
    blank = 0  # token id that is dropped as the CTC blank

    results = defaultdict(list)
    num_cuts = 0
    for batch_idx, batch in enumerate(dl):
        layer_results = extract_layers_result(
            w2v_model, batch=batch, device=params.device
        )

        # Output of the last layer -> encoder layer norm -> projection.
        encoder_out = w2v_model.encoder.layer_norm(
            layer_results[params.total_layers - 1][0]
        )
        encoder_out = model.w2v_encoder.proj(encoder_out.transpose(0, 1))

        # Greedy search: best token per frame, merge repeats, drop blanks.
        # NOTE(review): every frame of the padded batch is decoded here;
        # no padding mask is applied to the argmax output.
        toks = encoder_out.argmax(dim=-1)
        toks = [tok.unique_consecutive() for tok in toks]
        hyps = [
            processor.string(tok[tok != blank].int().cpu()) for tok in toks
        ]
        hyps = [post_process(hyp, "letter") for hyp in hyps]

        texts = batch["supervisions"]["text"]
        assert len(hyps) == len(texts)

        results["ctc_greedy_search"].extend(
            (ref_text.split(), hyp_text.split())
            for hyp_text, ref_text in zip(hyps, texts)
        )

        num_cuts += len(texts)

        if batch_idx % 20 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(
                f"batch {batch_str}, cuts processed until now is {num_cuts}"
            )
    return results
|
||||
|
||||
|
||||
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
) -> None:
    """Write transcripts, error statistics and a WER summary for a test set.

    For every entry of ``results_dict`` two files are written below
    ``params.exp_dir``: ``hubert-recogs-<test_set_name>-<key>.txt`` and
    ``hubert-errs-<test_set_name>-<key>.txt``. A final
    ``hubert-wer-summary-<test_set_name>.txt`` lists all settings sorted
    by the value returned from ``write_error_stats`` (ascending), and the
    same table is logged.

    Args:
      params:
        Settings object; only ``params.exp_dir`` is read, and it must
        support the ``/`` operator (e.g. a ``pathlib.Path``).
      test_set_name:
        Name of the test set; used in file names and log messages.
      results_dict:
        Maps a decoding setting to its list of
        (reference words, hypothesis words) pairs.
    """
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = params.exp_dir / f"hubert-recogs-{test_set_name}-{key}.txt"
        store_transcripts(filename=recog_path, texts=results)

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.exp_dir / f"hubert-errs-{test_set_name}-{key}.txt"
        )
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            test_set_wers[key] = wer

        logging.info(f"Wrote detailed error stats to {errs_filename}")

    # Lowest value first, so the first row is the best setting.
    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = params.exp_dir / f"hubert-wer-summary-{test_set_name}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print(f"{key}\t{val}", file=f)

    s = f"\nFor {test_set_name}, WER of different settings are:\n"
    # Only the first (best) row carries the annotation.
    note = f"\tbest for {test_set_name}"
    for key, val in test_set_wers:
        s += f"{key}\t{val}{note}\n"
        note = ""
    logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
def main():
    """Decode LibriSpeech test-clean and test-other with the HuBERT model.

    Parses the command line, merges ``vq_config`` on top of the parsed
    arguments, loads the model and then decodes and scores both test sets.
    """
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    # NOTE(review): assumes an ``--exp-dir`` option has been registered on
    # the parser by one of the two calls above - confirm.
    args.exp_dir = Path(args.exp_dir)

    # Entries of vq_config are applied last, so they take precedence over
    # command-line values that share the same name.
    params = AttributeDict()
    params.update(vars(args))
    params.update(vq_config)

    setup_logger(f"{params.exp_dir}/log-ctc_greedy_search/log-decode")
    logging.info("Decoding started")
    logging.info(params)

    # First CUDA device when one is available, otherwise the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")

    logging.info(f"device: {device}")
    params.device = device

    model, processor = load_hubert_model(params)

    librispeech = LibriSpeechAsrDataModule(params)

    clean_cuts = librispeech.test_clean_cuts()
    other_cuts = librispeech.test_other_cuts()

    named_loaders = (
        ("test-clean", librispeech.test_dataloaders(clean_cuts)),
        ("test-other", librispeech.test_dataloaders(other_cuts)),
    )

    for set_name, loader in named_loaders:
        decoded = decode_dataset(
            dl=loader,
            model=model,
            processor=processor,
            params=params,
        )
        save_results(
            params=params, test_set_name=set_name, results_dict=decoded
        )

    logging.info("Done!")
|
||||
|
||||
|
||||
# Limit PyTorch to one intra-op thread and one inter-op thread.
# These calls run at import time, before main() is invoked.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


if __name__ == "__main__":
    main()
|
@ -0,0 +1,162 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (author: Liyong Guo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
from fairseq import (
|
||||
checkpoint_utils,
|
||||
tasks,
|
||||
utils,
|
||||
)
|
||||
from fairseq.models.hubert.hubert import HubertModel
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
# Fixed settings for HuBERT inference and quantization.
vq_config = {
    # --- parameters about hubert model inference ---
    # The checkpoint is loaded from <model_dir>/<model_id>.pt.
    "model_dir": "./vq_pruned_transducer_stateless2/exp/hubert_models/",
    "model_id": "hubert_xtralarge_ll60k_finetune_ls960",
    # The model is fed raw audio samples, not precomputed features.
    "input_strategy": "AudioSamples",
    "enable_spec_aug": False,
    "enable_musan": False,
    "total_layers": 48,
    "memory_embedding_dim": 1280,
    # --- parameters about quantizer ---
    "num_utts": 100,
    "memory_layer": 36,
    "memory_dir": "./vq_pruned_transducer_stateless2/exp/mem/",
    "bytes_per_frame": 8,
    "refine_iter": 5,
    "enable_refine": True,
    # --- parameters about extracted codebook index ---
    "data_dir": "./data/",
}
|
||||
|
||||
|
||||
def get_parser():
    """Build the command-line parser shared by the HuBERT scripts.

    Returns:
      An ``argparse.ArgumentParser`` that defines ``--subset``,
      ``--job-idx``, ``--num-splits``, ``--quantizer-id``,
      ``--refine-iter`` and ``--ori-manifest-dir``. Defaults are shown in
      the help output for options that have a help string.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--subset",
        type=str,
    )

    parser.add_argument(
        "--job-idx",
        type=int,
    )

    parser.add_argument(
        "--num-splits",
        type=int,
    )

    parser.add_argument(
        "--quantizer-id",
        type=str,
        default=None,
        # The two adjacent literals used to be glued together without a
        # separator ("quantizer_idManully ..."); keep this as one string.
        help="quantizer_id. Manually set this in case of mistake.",
    )

    parser.add_argument(
        "--refine-iter",
        type=int,
        default=-1,
        help="number of refine iterations when extracting codebook indices",
    )

    parser.add_argument(
        "--ori-manifest-dir",
        type=str,
        default=None,
    )

    return parser
|
||||
|
||||
|
||||
def load_hubert_model(params):
    """Load a HuBERT checkpoint together with its target dictionary.

    Args:
      params:
        Settings object; ``params.model_dir``, ``params.model_id`` and
        ``params.device`` are read. The checkpoint path is
        ``<model_dir>/<model_id>.pt``.

    Returns:
      A tuple ``(model, processor)``: the first model of the loaded
      ensemble, moved to ``params.device`` and switched to eval mode, and
      the ``target_dictionary`` of the fairseq task.
    """
    checkpoint = Path(params.model_dir) / (params.model_id + ".pt")

    task_cfg = OmegaConf.create(
        {
            "_name": "hubert_pretraining",
            "single_target": True,
            "fine_tuning": True,
            "data": params.model_dir,
        }
    )
    task = tasks.setup_task(task_cfg)
    processor = task.target_dictionary

    # Only the models are needed; the saved config is discarded.
    ensemble, _ = checkpoint_utils.load_model_ensemble(
        utils.split_paths(str(checkpoint), separator="\\"),
        arg_overrides={},
        strict=True,
        suffix="",
        num_shards=1,
    )
    model = ensemble[0]
    model.to(params.device)
    model.eval()

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    return model, processor
|
||||
|
||||
|
||||
# Modified from HubertModel.forward to extract all middle layers output
def extract_layers_result(
    model: "HubertModel",
    batch: Dict,
    device: torch.device,
) -> list:
    """Run a HuBERT model on a batch of audio and return per-layer results.

    Args:
      model:
        The HuBERT model. ``forward_features``, ``layer_norm``,
        ``forward_padding_mask``, ``post_extract_proj`` and ``encoder``
        are used.
      batch:
        Must contain ``batch["inputs"]``, a 2-D tensor unpacked as
        ``(B, T)``, and ``batch["supervisions"]["num_samples"]``, a tensor
        holding the number of valid samples of each of the ``B`` rows.
      device:
        Device to which the inputs and the padding mask are moved before
        the model is called.

    Returns:
      The second value returned by ``model.encoder``: the per-layer
      results, indexable by layer number.
      NOTE(review): each entry is presumably a tuple whose element 0 is
      the layer's output tensor - confirm against fairseq.
    """
    features = batch["inputs"]

    # corresponding task.normalize in fairseq
    # NOTE(review): the statistics are computed over the whole padded
    # batch at once; presumably fairseq normalizes each utterance on its
    # own - confirm whether that matters for batch sizes > 1.
    features = torch.nn.functional.layer_norm(features, features.shape)

    num_samples = batch["supervisions"]["num_samples"]
    B, T = features.shape
    # True marks padding. Valid sample indices of a row are
    # 0 .. num_samples - 1, so index num_samples is already padding.
    padding_mask = torch.arange(0, T).expand(B, T) >= num_samples.reshape(
        [-1, 1]
    )

    padding_mask = padding_mask.to(device)
    features = features.to(device)

    features = model.forward_features(features)

    features = features.transpose(1, 2)
    features = model.layer_norm(features)

    # Convert the sample-level mask to the resolution of ``features``.
    padding_mask = model.forward_padding_mask(features, padding_mask)

    if model.post_extract_proj is not None:
        features = model.post_extract_proj(features)

    _, layer_results = model.encoder(
        features,
        padding_mask=padding_mask,
    )
    return layer_results
|
Loading…
x
Reference in New Issue
Block a user