decoder hubert model

commit fb9c0c3971 (parent 9d48f1ce7d)
@@ -34,7 +34,7 @@ from lhotse.dataset import (
     SingleCutSampler,
     SpecAugment,
 )
-from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.dataset.input_strategies import AudioSamples, OnTheFlyFeatures
 from lhotse.utils import fix_random_seed
 from torch.utils.data import DataLoader
 
@@ -192,6 +192,13 @@ class LibriSpeechAsrDataModule:
             "with training dataset. ",
         )
 
+        group.add_argument(
+            "--input-strategy",
+            type=str,
+            default="PrecomputedFeatures",
+            help="AudioSamples or PrecomputedFeatures",
+        )
+
     def train_dataloaders(
         self,
         cuts_train: CutSet,
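The new flag selects lhotse's input strategy by name: AudioSamples batches raw waveforms of shape (B, num_samples), which is what the HuBERT model consumes, while PrecomputedFeatures batches feature matrices of shape (B, num_frames, num_features). A minimal sketch, not part of the commit, of how the parsed flag maps to a strategy instance:

import argparse

from lhotse.dataset.input_strategies import AudioSamples, PrecomputedFeatures

parser = argparse.ArgumentParser()
parser.add_argument("--input-strategy", type=str, default="PrecomputedFeatures")
args = parser.parse_args(["--input-strategy", "AudioSamples"])

# Mirror of the dispatch used in train_dataloaders() below.
strategy = (
    AudioSamples()
    if args.input_strategy == "AudioSamples"
    else PrecomputedFeatures()
)
print(type(strategy).__name__)  # -> AudioSamples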
@@ -263,6 +270,9 @@ class LibriSpeechAsrDataModule:
 
         logging.info("About to create train dataset")
         train = K2SpeechRecognitionDataset(
+            input_strategy=AudioSamples()
+            if self.args.input_strategy == "AudioSamples"
+            else PrecomputedFeatures(),
             cut_transforms=transforms,
             input_transforms=input_transforms,
             return_cuts=self.args.return_cuts,
@@ -371,7 +381,7 @@ class LibriSpeechAsrDataModule:
         test = K2SpeechRecognitionDataset(
             input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
             if self.args.on_the_fly_feats
-            else PrecomputedFeatures(),
+            else eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
         )
         sampler = BucketingSampler(
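The test branch resolves the class with eval(self.args.input_strategy)(), which works for the two documented values but evaluates arbitrary strings. A minimal alternative sketch, not part of the commit (the helper name is hypothetical), that fails loudly on typos instead:

from lhotse.dataset.input_strategies import AudioSamples, PrecomputedFeatures

_INPUT_STRATEGIES = {
    "AudioSamples": AudioSamples,
    "PrecomputedFeatures": PrecomputedFeatures,
}


def make_input_strategy(name: str):
    """Return an input-strategy instance for a whitelisted name."""
    try:
        return _INPUT_STRATEGIES[name]()
    except KeyError:
        raise ValueError(f"Unknown --input-strategy: {name!r}")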
egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_decode.py (new executable file, 211 lines)
@@ -0,0 +1,211 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from fairseq.data.data_utils import post_process

from asr_datamodule import LibriSpeechAsrDataModule
from hubert_utils import (
    extract_layers_result,
    get_parser,
    load_hubert_model,
    vq_config,
)

from icefall.utils import (
    AttributeDict,
    setup_logger,
    store_transcripts,
    write_error_stats,
)


def decode_dataset(
    dl: torch.utils.data.DataLoader,
    model: nn.Module,
    processor,
    params,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode dataset with CTC greedy search.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      model:
        The fine-tuned HuBERT model.
      processor:
        The fairseq target dictionary used to map token ids to strings.
      params:
        Parameters combining the command-line arguments and vq_config.

    Returns:
      Return a dict whose key is the decoding method, here always
      "ctc_greedy_search". Its value is a list of tuples. Each tuple
      contains two elements: The first is the reference transcript,
      and the second is the predicted result.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        w2v_model = model.w2v_encoder.w2v_model
        layer_results = extract_layers_result(
            w2v_model, batch=batch, device=params.device
        )

        encoder_out = w2v_model.encoder.layer_norm(
            layer_results[params.total_layers - 1][0]
        )
        encoder_out = model.w2v_encoder.proj(encoder_out.transpose(0, 1))

        # CTC greedy search: frame-wise argmax, merge consecutive repeats,
        # then drop the blank token (id 0).
        toks = encoder_out.argmax(dim=-1)
        blank = 0
        toks = [tok.unique_consecutive() for tok in toks]
        hyps = [processor.string(tok[tok != blank].int().cpu()) for tok in toks]
        hyps = [post_process(hyp, "letter") for hyp in hyps]

        texts = batch["supervisions"]["text"]

        this_batch = []
        assert len(hyps) == len(texts)
        for hyp_text, ref_text in zip(hyps, texts):
            ref_words = ref_text.split()
            hyp_words = hyp_text.split()
            this_batch.append((ref_words, hyp_words))

        results["ctc_greedy_search"].extend(this_batch)

        num_cuts += len(texts)

        if batch_idx % 20 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(
                f"batch {batch_str}, cuts processed until now is {num_cuts}"
            )
    return results
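The greedy-search collapse above can be checked in isolation. A minimal, self-contained sketch with a toy three-symbol vocabulary (0 is the blank):

import torch

# Frame-wise posteriors over a toy vocabulary {0: blank, 1: "A", 2: "B"}.
log_probs = torch.tensor(
    [
        [0.1, 0.8, 0.1],
        [0.1, 0.8, 0.1],
        [0.9, 0.05, 0.05],
        [0.1, 0.1, 0.8],
    ]
)
toks = log_probs.argmax(dim=-1)   # tensor([1, 1, 0, 2])
toks = toks.unique_consecutive()  # merge repeats -> tensor([1, 0, 2])
toks = toks[toks != 0]            # drop blanks   -> tensor([1, 2]), i.e. "A B"
print(toks)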


def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = params.exp_dir / f"hubert-recogs-{test_set_name}-{key}.txt"
        store_transcripts(filename=recog_path, texts=results)

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.exp_dir / f"hubert-errs-{test_set_name}-{key}.txt"
        )
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            test_set_wers[key] = wer

        logging.info(
            "Wrote detailed error stats to {}".format(errs_filename)
        )

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = params.exp_dir / f"hubert-wer-summary-{test_set_name}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)
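save_results sorts the settings by WER ascending, so the first row of the summary is the best one. A toy illustration with made-up numbers (this script itself only ever produces the "ctc_greedy_search" key):

# Hypothetical WERs for two decoding settings.
test_set_wers = {"ctc_greedy_search": 2.6, "another_setting": 3.1}
for key, val in sorted(test_set_wers.items(), key=lambda x: x[1]):
    print(f"{key}\t{val}")
# ctc_greedy_search  2.6
# another_setting    3.1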


@torch.no_grad()
def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = AttributeDict()
    params.update(vars(args))
    params.update(vq_config)

    setup_logger(f"{params.exp_dir}/log-ctc_greedy_search/log-decode")
    logging.info("Decoding started")
    logging.info(params)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")
    params.device = device

    model, processor = load_hubert_model(params)

    librispeech = LibriSpeechAsrDataModule(params)

    test_clean_cuts = librispeech.test_clean_cuts()
    test_other_cuts = librispeech.test_other_cuts()

    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
    test_other_dl = librispeech.test_dataloaders(test_other_cuts)

    test_sets = ["test-clean", "test-other"]
    test_dls = [test_clean_dl, test_other_dl]

    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=test_dl,
            model=model,
            processor=processor,
            params=params,
        )

        save_results(
            params=params, test_set_name=test_set, results_dict=results_dict
        )

    logging.info("Done!")


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()

egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_utils.py (new file, 162 lines)
@@ -0,0 +1,162 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
from pathlib import Path
from typing import Dict, List, Tuple

import torch

from fairseq import (
    checkpoint_utils,
    tasks,
    utils,
)
from fairseq.models.hubert.hubert import HubertModel
from omegaconf import OmegaConf

vq_config = {
    # parameters about hubert model inference.
    "model_dir": "./vq_pruned_transducer_stateless2/exp/hubert_models/",
    "model_id": "hubert_xtralarge_ll60k_finetune_ls960",
    "input_strategy": "AudioSamples",
    "enable_spec_aug": False,
    "enable_musan": False,
    "total_layers": 48,
    "memory_embedding_dim": 1280,
    # parameters about quantizer.
    "num_utts": 100,
    "memory_layer": 36,
    "memory_dir": "./vq_pruned_transducer_stateless2/exp/mem/",
    "bytes_per_frame": 8,
    "refine_iter": 5,
    "enable_refine": True,
    # parameters about extracted codebook index.
    "data_dir": "./data/",
}
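hubert_decode.py layers these settings over the parsed command-line arguments with params.update(vars(args)) followed by params.update(vq_config), so vq_config wins on any overlapping key. A minimal sketch of that precedence, using the vq_config dict defined above (the stand-in CLI dict is hypothetical):

from icefall.utils import AttributeDict

params = AttributeDict()
params.update({"input_strategy": "PrecomputedFeatures"})  # e.g. a CLI default
params.update(vq_config)  # vq_config overrides: forces "AudioSamples"
print(params.input_strategy)  # -> AudioSamples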


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--subset",
        type=str,
    )

    parser.add_argument(
        "--job-idx",
        type=int,
    )

    parser.add_argument(
        "--num-splits",
        type=int,
    )

    parser.add_argument(
        "--quantizer-id",
        type=str,
        default=None,
        help="The quantizer id. Manually set this in case of mistake.",
    )

    parser.add_argument(
        "--refine-iter",
        type=int,
        default=-1,
        help="number of refine iterations when extracting codebook indices",
    )

    parser.add_argument(
        "--ori-manifest-dir",
        type=str,
        default=None,
    )

    return parser
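A brief usage sketch with hypothetical argument values (--subset, --job-idx and --num-splits presumably shard the codebook-index extraction across parallel jobs):

parser = get_parser()
args = parser.parse_args(
    ["--subset", "train-clean-100", "--job-idx", "0", "--num-splits", "4"]
)
print(args.subset, args.job_idx, args.num_splits)  # train-clean-100 0 4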


def load_hubert_model(params):
    cfg_task = OmegaConf.create(
        {
            "_name": "hubert_pretraining",
            "single_target": True,
            "fine_tuning": True,
            "data": params.model_dir,
        }
    )
    model_path = Path(params.model_dir) / (params.model_id + ".pt")
    task = tasks.setup_task(cfg_task)
    processor = task.target_dictionary
    models, saved_cfg = checkpoint_utils.load_model_ensemble(
        utils.split_paths(str(model_path), separator="\\"),
        arg_overrides={},
        strict=True,
        suffix="",
        num_shards=1,
    )
    model = models[0]
    model.to(params.device)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    return model, processor
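A usage sketch, assuming the checkpoint <model_dir>/<model_id>.pt and the fairseq dictionary files already exist under model_dir (the paths are the defaults from vq_config above):

import torch

from icefall.utils import AttributeDict

params = AttributeDict(
    {
        "model_dir": "./vq_pruned_transducer_stateless2/exp/hubert_models/",
        "model_id": "hubert_xtralarge_ll60k_finetune_ls960",
        "device": torch.device("cpu"),
    }
)
# "processor" is the fairseq target dictionary, used later to map
# CTC token ids back to letter strings.
model, processor = load_hubert_model(params)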


# Modified from HubertModel.forward to extract the output of all
# intermediate layers.
def extract_layers_result(
    model: HubertModel,
    batch: Dict,
    device: torch.device,
) -> List[Tuple]:
    features = batch["inputs"]

    # corresponds to task.normalize in fairseq
    features = torch.nn.functional.layer_norm(features, features.shape)

    supervisions = batch["supervisions"]
    num_samples = supervisions["num_samples"]
    B, T = features.shape
    padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape(
        [-1, 1]
    )

    padding_mask = padding_mask.to(device)
    features = features.to(device)

    features = model.forward_features(features)

    features = features.transpose(1, 2)
    features = model.layer_norm(features)

    if padding_mask is not None:
        padding_mask = model.forward_padding_mask(features, padding_mask)

    if model.post_extract_proj is not None:
        features = model.post_extract_proj(features)

    _, layer_results = model.encoder(
        features,
        padding_mask=padding_mask,
    )
    return layer_results
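The padding mask built above can be sanity-checked on its own. A minimal sketch with toy lengths: positions whose index exceeds an utterance's sample count are marked True and later excluded by the encoder.

import torch

# Toy lengths for a batch of two "waveforms" padded to T = 5 samples.
num_samples = torch.tensor([3, 5])
B, T = 2, 5
padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape([-1, 1])
print(padding_mask)
# tensor([[False, False, False, False,  True],
#         [False, False, False, False, False]])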