From fb9c0c3971e49bb082ed240216f1f3fa4b8e584c Mon Sep 17 00:00:00 2001
From: Guo Liyong
Date: Thu, 28 Apr 2022 00:05:06 +0800
Subject: [PATCH] Decode with the HuBERT model

---
 .../asr_datamodule.py                         |  14 +-
 .../hubert_decode.py                          | 213 ++++++++++++++++++
 .../hubert_utils.py                           | 166 ++++++++++++++
 3 files changed, 391 insertions(+), 2 deletions(-)
 create mode 100755 egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_decode.py
 create mode 100644 egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_utils.py

diff --git a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/asr_datamodule.py b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/asr_datamodule.py
index 8dd1459ca..e8af4362f 100644
--- a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/asr_datamodule.py
@@ -34,7 +34,7 @@ from lhotse.dataset import (
     SingleCutSampler,
     SpecAugment,
 )
-from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.dataset.input_strategies import AudioSamples, OnTheFlyFeatures
 from lhotse.utils import fix_random_seed
 from torch.utils.data import DataLoader
 
@@ -192,6 +192,13 @@ class LibriSpeechAsrDataModule:
             "with training dataset. ",
         )
 
+        group.add_argument(
+            "--input-strategy",
+            type=str,
+            default="PrecomputedFeatures",
+            help="AudioSamples or PrecomputedFeatures",
+        )
+
     def train_dataloaders(
         self,
         cuts_train: CutSet,
@@ -263,6 +270,9 @@
 
         logging.info("About to create train dataset")
         train = K2SpeechRecognitionDataset(
+            input_strategy=AudioSamples()
+            if self.args.input_strategy == "AudioSamples"
+            else PrecomputedFeatures(),
             cut_transforms=transforms,
             input_transforms=input_transforms,
             return_cuts=self.args.return_cuts,
@@ -371,7 +381,7 @@
         test = K2SpeechRecognitionDataset(
             input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
             if self.args.on_the_fly_feats
-            else PrecomputedFeatures(),
+            else eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,
         )
         sampler = BucketingSampler(
diff --git a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_decode.py b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_decode.py
new file mode 100755
index 000000000..b4a290525
--- /dev/null
+++ b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_decode.py
@@ -0,0 +1,213 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corporation (Author: Liyong Guo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import torch
+import torch.nn as nn
+from fairseq.data.data_utils import post_process
+
+from asr_datamodule import LibriSpeechAsrDataModule
+from hubert_utils import (
+    extract_layers_result,
+    get_parser,
+    load_hubert_model,
+    vq_config,
+)
+
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    store_transcripts,
+    write_error_stats,
+)
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    model: nn.Module,
+    processor,
+    params,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      model:
+        The neural model.
+      processor:
+        The fairseq target dictionary; it maps token ids to strings.
+      params:
+        Decoding options, e.g. the device and the number of encoder layers.
+
+    Returns:
+      Return a dict whose key is the decoding method; only
+      "ctc_greedy_search" is used here.
+      Its value is a list of tuples. Each tuple contains two elements:
+      The first is the reference transcript, and the second is the
+      predicted result.
+    """
+
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+
+        w2v_model = model.w2v_encoder.w2v_model
+        layer_results = extract_layers_result(
+            w2v_model, batch=batch, device=params.device
+        )
+
+        encoder_out = w2v_model.encoder.layer_norm(
+            layer_results[params.total_layers - 1][0]
+        )
+        encoder_out = model.w2v_encoder.proj(encoder_out.transpose(0, 1))
+
+        toks = encoder_out.argmax(dim=-1)
+        blank = 0
+        toks = [tok.unique_consecutive() for tok in toks]
+        hyps = [processor.string(tok[tok != blank].int().cpu()) for tok in toks]
+        hyps = [post_process(hyp, "letter") for hyp in hyps]
+
+        texts = batch["supervisions"]["text"]
+
+        this_batch = []
+        assert len(hyps) == len(texts)
+
+        for hyp_text, ref_text in zip(hyps, texts):
+            ref_words = ref_text.split()
+            hyp_words = hyp_text.split()
+            this_batch.append((ref_words, hyp_words))
+
+        results["ctc_greedy_search"].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % 20 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = params.exp_dir / f"hubert-recogs-{test_set_name}-{key}.txt"
+        store_transcripts(filename=recog_path, texts=results)
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.exp_dir / f"hubert-errs-{test_set_name}-{key}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info(
+            "Wrote detailed error stats to {}".format(errs_filename)
+        )
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = params.exp_dir / f"hubert-wer-summary-{test_set_name}.txt"
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = AttributeDict()
+    params.update(vars(args))
+    params.update(vq_config)
+
+    setup_logger(f"{params.exp_dir}/log-ctc_greedy_search/log-decode")
+    logging.info("Decoding started")
+    logging.info(params)
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+    params.device = device
+
+    model, processor = load_hubert_model(params)
+
+    librispeech = LibriSpeechAsrDataModule(params)
+
+    test_clean_cuts = librispeech.test_clean_cuts()
+    test_other_cuts = librispeech.test_other_cuts()
+
+    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
+    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+
+    test_sets = ["test-clean", "test-other"]
+    test_dls = [test_clean_dl, test_other_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            model=model,
+            processor=processor,
+            params=params,
+        )
+
+        save_results(
+            params=params, test_set_name=test_set, results_dict=results_dict
+        )
+
+    logging.info("Done!")
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_utils.py b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_utils.py
new file mode 100644
index 000000000..5979c8aad
--- /dev/null
+++ b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/hubert_utils.py
@@ -0,0 +1,166 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (author: Liyong Guo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import torch
+
+from fairseq import (
+    checkpoint_utils,
+    tasks,
+    utils,
+)
+from fairseq.models.hubert.hubert import HubertModel
+from omegaconf import OmegaConf
+
+vq_config = {
+    # parameters about hubert model inference.
+    "model_dir": "./vq_pruned_transducer_stateless2/exp/hubert_models/",
+    "model_id": "hubert_xtralarge_ll60k_finetune_ls960",
+    "input_strategy": "AudioSamples",
+    "enable_spec_aug": False,
+    "enable_musan": False,
+    "total_layers": 48,
+    "memory_embedding_dim": 1280,
+    # parameters about quantizer.
+    "num_utts": 100,
+    "memory_layer": 36,
+    "memory_dir": "./vq_pruned_transducer_stateless2/exp/mem/",
+    "bytes_per_frame": 8,
+    "refine_iter": 5,
+    "enable_refine": True,
+    # parameters about extracted codebook index.
+    "data_dir": "./data/",
+}
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--subset",
+        type=str,
+        help="The dataset subset to process.",
+    )
+
+    parser.add_argument(
+        "--job-idx",
+        type=int,
+        help="Index of this job among the num-splits parallel jobs.",
+    )
+
+    parser.add_argument(
+        "--num-splits",
+        type=int,
+        help="Number of splits the manifest is divided into.",
+    )
+
+    parser.add_argument(
+        "--quantizer-id",
+        type=str,
+        default=None,
+        help="The quantizer id. "
+        "Set this manually in case of mistakes.",
+    )
+
+    parser.add_argument(
+        "--refine-iter",
+        type=int,
+        default=-1,
+        help="Number of refine iterations when extracting codebook indexes.",
+    )
+
+    parser.add_argument(
+        "--ori-manifest-dir",
+        type=str,
+        default=None,
+        help="Directory of the original manifests.",
+    )
+
+    return parser
+
+
+def load_hubert_model(params):
+    cfg_task = OmegaConf.create(
+        {
+            "_name": "hubert_pretraining",
+            "single_target": True,
+            "fine_tuning": True,
+            "data": params.model_dir,
+        }
+    )
+    model_path = Path(params.model_dir) / (params.model_id + ".pt")
+    task = tasks.setup_task(cfg_task)
+    processor = task.target_dictionary
+    models, saved_cfg = checkpoint_utils.load_model_ensemble(
+        utils.split_paths(str(model_path), separator="\\"),
+        arg_overrides={},
+        strict=True,
+        suffix="",
+        num_shards=1,
+    )
+    model = models[0]
+    model.to(params.device)
+    model.eval()
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    return model, processor
+
+
+# Modified from HubertModel.forward to extract every layer's output
+def extract_layers_result(
+    model: HubertModel,
+    batch: Dict,
+    device: torch.device,
+) -> List[Tuple]:
+    features = batch["inputs"]
+
+    # corresponds to task.normalize in fairseq
+    features = torch.nn.functional.layer_norm(features, features.shape)
+
+    supervisions = batch["supervisions"]
+    num_samples = supervisions["num_samples"]
+    B, T = features.shape
+    padding_mask = torch.arange(0, T).expand(B, T) >= num_samples.reshape(
+        [-1, 1]
+    )
+
+    padding_mask = padding_mask.to(device)
+    features = features.to(device)
+
+    features = model.forward_features(features)
+
+    features = features.transpose(1, 2)
+    features = model.layer_norm(features)
+
+    if padding_mask is not None:
+        padding_mask = model.forward_padding_mask(features, padding_mask)
+
+    if model.post_extract_proj is not None:
+        features = model.post_extract_proj(features)
+
+    _, layer_results = model.encoder(
+        features,
+        padding_mask=padding_mask,
+    )
+    return layer_results
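
Note: for reference, below is a minimal, self-contained sketch of the CTC
greedy-search step that decode_dataset performs (frame-wise argmax over the
vocabulary, collapsing repeated tokens, then dropping blanks). The tensor
shapes and the 5-token vocabulary are toy values chosen for illustration;
they are not part of the patch.

    import torch

    # (batch, time, vocab) log-probabilities; index 0 plays the CTC blank.
    log_probs = torch.randn(2, 10, 5).log_softmax(dim=-1)

    blank = 0
    toks = log_probs.argmax(dim=-1)  # best token id per frame: (batch, time)
    toks = [t.unique_consecutive() for t in toks]  # collapse repeated ids
    hyps = [t[t != blank].tolist() for t in toks]  # drop blank frames
    print(hyps)  # e.g. [[3, 1, 4], [2, 4]]; one id sequence per utterance

In the actual script, these token ids are further mapped to letters with
processor.string() and post_process(hyp, "letter") before WER scoring.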