decoder hubert model

This commit is contained in:
Guo Liyong 2022-04-28 00:05:06 +08:00
parent 9d48f1ce7d
commit fb9c0c3971
3 changed files with 385 additions and 2 deletions

View File

@ -34,7 +34,7 @@ from lhotse.dataset import (
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.dataset.input_strategies import AudioSamples, OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
@ -192,6 +192,13 @@ class LibriSpeechAsrDataModule:
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
@ -263,6 +270,9 @@ class LibriSpeechAsrDataModule:
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=AudioSamples()
if self.args.input_strategy == "AudioSamples"
else PrecomputedFeatures(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
@ -371,7 +381,7 @@ class LibriSpeechAsrDataModule:
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = BucketingSampler(

View File

@ -0,0 +1,211 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from fairseq.data.data_utils import post_process
from asr_datamodule import LibriSpeechAsrDataModule
from hubert_utils import (
extract_layers_result,
load_hubert_model,
get_parser,
vq_config,
)
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    model: nn.Module,
    processor,
    params,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode a dataset with CTC greedy search.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      model:
        The neural model. ``model.w2v_encoder.w2v_model`` and
        ``model.w2v_encoder.proj`` are used.
      processor:
        Object whose ``string()`` method converts a 1-D tensor of token
        ids to text.
      params:
        Run-time parameters. ``params.device`` and ``params.total_layers``
        are read.
    Returns:
      Return a dict with the single key "ctc_greedy_search" (the dict is
      empty if the dataloader yields no batch).
      Its value is a list of tuples. Each tuple contains two elements:
      The first is the reference transcript, and the second is the
      predicted result, both as lists of words.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        # The dataloader has no defined length.
        num_batches = "?"

    # Id of the CTC blank symbol; it is removed from every hypothesis.
    blank = 0
    # Does not depend on the batch, so look it up only once.
    w2v_model = model.w2v_encoder.w2v_model

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        layer_results = extract_layers_result(
            w2v_model, batch=batch, device=params.device
        )

        # Output of the last encoder layer.
        encoder_out = w2v_model.encoder.layer_norm(
            layer_results[params.total_layers - 1][0]
        )
        # assumes the encoder output is (T, B, C) and proj expects
        # (B, T, C) -- TODO confirm
        encoder_out = model.w2v_encoder.proj(encoder_out.transpose(0, 1))

        # Greedy search: best token per frame, merge repeats, drop blanks.
        # NOTE(review): frames in the padded region are decoded as well;
        # confirm that they never produce non-blank tokens.
        toks = encoder_out.argmax(dim=-1)
        toks = [tok.unique_consecutive() for tok in toks]
        hyps = [
            processor.string(tok[tok != blank].int().cpu()) for tok in toks
        ]
        hyps = [post_process(hyp, "letter") for hyp in hyps]

        texts = batch["supervisions"]["text"]
        assert len(hyps) == len(texts)

        results["ctc_greedy_search"].extend(
            (ref_text.split(), hyp_text.split())
            for hyp_text, ref_text in zip(hyps, texts)
        )

        num_cuts += len(texts)

        if batch_idx % 20 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(
                f"batch {batch_str}, cuts processed until now is {num_cuts}"
            )
    return results
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
    """Write transcripts, error statistics and a WER summary for a test set.

    All files are written into ``params.exp_dir``.

    Args:
      params:
        Run-time parameters; only ``params.exp_dir`` is read.
      test_set_name:
        Name of the test set; it becomes part of every output filename.
      results_dict:
        Maps a decoding setting to its list of (reference words,
        hypothesis words) pairs.
    """
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.exp_dir / f"hubert-recogs-{test_set_name}-{key}.txt"
        )
        store_transcripts(filename=recog_path, texts=results)

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.exp_dir / f"hubert-errs-{test_set_name}-{key}.txt"
        )
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            test_set_wers[key] = wer

        logging.info(f"Wrote detailed error stats to {errs_filename}")

    # Lowest WER first, so the best setting heads the summary.
    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = params.exp_dir / f"hubert-wer-summary-{test_set_name}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print(f"{key}\t{val}", file=f)

    s = f"\nFor {test_set_name}, WER of different settings are:\n"
    # Only the first (best) line carries the note.
    note = f"\tbest for {test_set_name}"
    for key, val in test_set_wers:
        s += f"{key}\t{val}{note}\n"
        note = ""
    logging.info(s)
@torch.no_grad()
def main():
    """Decode LibriSpeech test-clean and test-other with the HuBERT model.

    Parses the command line, loads the model, runs ``decode_dataset`` on
    both test sets and writes the results with ``save_results``.
    """
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    # NOTE(review): an exp_dir attribute is required here. Confirm that one
    # of the two parser set-up calls above registers --exp-dir.
    args.exp_dir = Path(args.exp_dir)

    params = AttributeDict()
    params.update(vars(args))
    # Applied after the command-line values.
    # NOTE(review): presumably entries of vq_config therefore replace
    # same-named command-line options (e.g. refine_iter, input_strategy)
    # -- confirm this is intended.
    params.update(vq_config)

    setup_logger(f"{params.exp_dir}/log-ctc_greedy_search/log-decode")
    logging.info("Decoding started")
    logging.info(params)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    params.device = device

    model, processor = load_hubert_model(params)

    librispeech = LibriSpeechAsrDataModule(params)

    test_clean_cuts = librispeech.test_clean_cuts()
    test_other_cuts = librispeech.test_other_cuts()

    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
    test_other_dl = librispeech.test_dataloaders(test_other_cuts)

    test_sets = ["test-clean", "test-other"]
    # Named differently from the loop variable below so that the list is
    # not shadowed while it is being iterated.
    test_dls = [test_clean_dl, test_other_dl]

    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=test_dl,
            model=model,
            processor=processor,
            params=params,
        )

        save_results(
            params=params, test_set_name=test_set, results_dict=results_dict
        )

    logging.info("Done!")
# Restrict PyTorch to one intra-op thread and one inter-op thread. This runs
# at import time, before main() is called.
# NOTE(review): presumably done to avoid CPU over-subscription next to the
# dataloader workers -- confirm.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

# Run the decoding only when this file is executed as a script.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1,162 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from typing import Dict
import torch
from fairseq import (
checkpoint_utils,
tasks,
utils,
)
from fairseq.models.hubert.hubert import HubertModel
from omegaconf import OmegaConf
# Default settings for the HuBERT-based scripts. The decoding script's main()
# applies this dict on top of the parsed command-line arguments.
vq_config = {
    # parameters about hubert model inference.
    # load_hubert_model() loads the checkpoint "<model_dir>/<model_id>.pt".
    "model_dir": "./vq_pruned_transducer_stateless2/exp/hubert_models/",
    "model_id": "hubert_xtralarge_ll60k_finetune_ls960",
    # The data module builds its training dataset with AudioSamples() when
    # this equals "AudioSamples".
    "input_strategy": "AudioSamples",
    "enable_spec_aug": False,
    "enable_musan": False,
    # The decoding script takes the layer result at index total_layers - 1
    # as the final encoder output.
    "total_layers": 48,
    # NOTE(review): presumably the feature dimension of the layer used as
    # "memory" -- confirm.
    "memory_embedding_dim": 1280,
    # parameters about quantizer.
    "num_utts": 100,
    # NOTE(review): presumably the encoder layer whose output is quantized
    # -- confirm.
    "memory_layer": 36,
    "memory_dir": "./vq_pruned_transducer_stateless2/exp/mem/",
    "bytes_per_frame": 8,
    # get_parser() also defines a --refine-iter option (default -1) with the
    # same destination name.
    "refine_iter": 5,
    "enable_refine": True,
    # parameters about extracted codebook index.
    "data_dir": "./data/",
}
def get_parser():
    """Create the command-line parser with the HuBERT-related options.

    Returns:
      An ``argparse.ArgumentParser`` defining ``--subset``, ``--job-idx``,
      ``--num-splits``, ``--quantizer-id``, ``--refine-iter`` and
      ``--ori-manifest-dir``. Callers may register further options on it
      before parsing.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--subset",
        type=str,
    )

    parser.add_argument(
        "--job-idx",
        type=int,
    )

    parser.add_argument(
        "--num-splits",
        type=int,
    )

    parser.add_argument(
        "--quantizer-id",
        type=str,
        default=None,
        # Previously two adjacent literals were concatenated without a
        # separator, which rendered as "quantizer_idManully ...".
        help="quantizer_id. Manually set this in case of mistake.",
    )

    parser.add_argument(
        "--refine-iter",
        type=int,
        default=-1,
        help="number of refine iterations when extracting codebook indices",
    )

    parser.add_argument(
        "--ori-manifest-dir",
        type=str,
        default=None,
    )

    return parser
def load_hubert_model(params):
    """Load the HuBERT checkpoint described by ``params``.

    Args:
      params:
        Must provide ``model_dir``, ``model_id`` and ``device``. The
        checkpoint file is ``<model_dir>/<model_id>.pt``.
    Returns:
      A tuple ``(model, processor)``: the first model of the loaded
      ensemble, moved to ``params.device`` and put in eval mode, and the
      target dictionary of the fairseq task.
    """
    checkpoint = Path(params.model_dir) / (params.model_id + ".pt")

    # Set up the fairseq task first; its target dictionary is returned to
    # the caller as the processor.
    task_cfg = OmegaConf.create(
        {
            "_name": "hubert_pretraining",
            "single_target": True,
            "fine_tuning": True,
            "data": params.model_dir,
        }
    )
    hubert_task = tasks.setup_task(task_cfg)
    processor = hubert_task.target_dictionary

    checkpoint_paths = utils.split_paths(str(checkpoint), separator="\\")
    ensemble, _ = checkpoint_utils.load_model_ensemble(
        checkpoint_paths,
        arg_overrides={},
        strict=True,
        suffix="",
        num_shards=1,
    )

    # Only the first model of the ensemble is used.
    model = ensemble[0]
    model.to(params.device)
    model.eval()

    num_param = sum(weight.numel() for weight in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    return model, processor
# Modified from HubertModel.forward to expose the outputs of all encoder
# layers.
def extract_layers_result(
    model: HubertModel,
    batch: Dict,
    device: torch.device,
) -> list:
    """Run the HuBERT feature extractor and encoder on a batch.

    Args:
      model:
        The HuBERT model.
      batch:
        Must contain ``batch["inputs"]``, a 2-D tensor unpacked as
        ``(B, T)``, and ``batch["supervisions"]["num_samples"]`` with one
        entry per row of the inputs.
        NOTE(review): presumably raw audio samples padded to a common
        length, with num_samples the unpadded length -- confirm.
      device:
        Device on which the model is run.
    Returns:
      The second value returned by ``model.encoder``.
      NOTE(review): presumably a list with one entry per encoder layer;
      the decoding script indexes it as ``layer_results[layer][0]`` --
      confirm against fairseq.
    """
    features = batch["inputs"]

    # corresponding task.normalize in fairseq
    # NOTE(review): this normalizes over the whole padded batch at once,
    # padding included; confirm this matches the per-utterance
    # normalization the model was trained with.
    features = torch.nn.functional.layer_norm(features, features.shape)

    supervisions = batch["supervisions"]
    num_samples = supervisions["num_samples"]
    B, T = features.shape
    # Valid sample indices are 0 .. num_samples - 1, so index i is padding
    # iff i >= num_samples (">" would keep one padded sample as valid).
    padding_mask = torch.arange(0, T).expand(B, T) >= num_samples.reshape(
        [-1, 1]
    )

    padding_mask = padding_mask.to(device)
    features = features.to(device)

    features = model.forward_features(features)

    features = features.transpose(1, 2)
    features = model.layer_norm(features)

    # Convert the sample-level mask to the resolution of the extracted
    # features. The mask is always built above, so no None check is needed.
    padding_mask = model.forward_padding_mask(features, padding_mask)

    if model.post_extract_proj is not None:
        features = model.post_extract_proj(features)

    _, layer_results = model.encoder(
        features,
        padding_mask=padding_mask,
    )
    return layer_results