From a9ad9553b5b0a7b568bb094b6a3f7f707aeba185 Mon Sep 17 00:00:00 2001
From: Guo Liyong
Date: Thu, 23 Dec 2021 19:08:35 +0800
Subject: [PATCH] use wav2vec as a teacher model

---
 .../ASR/conformer_ctc/quantizer_train.py      |  22 +-
 .../ASR/conformer_ctc/wav2vec_code_indices.py | 311 ++++++++++++++++++
 .../ASR/conformer_ctc/wav2vec_decode.py       | 265 +++++++++++++++
 .../conformer_ctc/wav2vec_memory_embedding.py | 251 ++++++++++++++
 4 files changed, 835 insertions(+), 14 deletions(-)
 create mode 100755 egs/librispeech/ASR/conformer_ctc/wav2vec_code_indices.py
 create mode 100755 egs/librispeech/ASR/conformer_ctc/wav2vec_decode.py
 create mode 100755 egs/librispeech/ASR/conformer_ctc/wav2vec_memory_embedding.py

diff --git a/egs/librispeech/ASR/conformer_ctc/quantizer_train.py b/egs/librispeech/ASR/conformer_ctc/quantizer_train.py
index df4b8b0bf..c88a29f6c 100755
--- a/egs/librispeech/ASR/conformer_ctc/quantizer_train.py
+++ b/egs/librispeech/ASR/conformer_ctc/quantizer_train.py
@@ -59,7 +59,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--output-layer-index",
+        "--mem-layer",
         type=int,
         default=None,
         help="which layer to extract memory embedding"
@@ -69,14 +69,10 @@ def get_parser():
     return parser


-def initialize_memory_dataloader(
-    mem_dir: Path = None, output_layer_index: int = None
-):
+def initialize_memory_dataloader(mem_dir: Path = None, mem_layer: int = None):
     assert mem_dir is not None
-    assert output_layer_index is not None
-    mem_manifest_file = (
-        mem_dir / f"{output_layer_index}layer-memory_manifest.json"
-    )
+    assert mem_layer is not None
+    mem_manifest_file = mem_dir / f"{mem_layer}layer-memory_manifest.json"
     assert os.path.isfile(
         mem_manifest_file
     ), f"{mem_manifest_file} does not exist."
@@ -95,14 +91,14 @@ def initialize_memory_dataloader(
 def main():
     parser = get_parser()
     args = parser.parse_args()
-    assert args.output_layer_index is not None
+    assert args.mem_layer is not None
     setup_logger(f"{args.mem_dir}/log/quantizer_train")
     trainer = quantization.QuantizerTrainer(
         dim=args.memory_embedding_dim,
         bytes_per_frame=args.bytes_per_frame,
         device=torch.device("cuda"),
     )
-    dl = initialize_memory_dataloader(args.mem_dir, args.output_layer_index)
+    dl = initialize_memory_dataloader(args.mem_dir, args.mem_layer)
     num_cuts = 0
     done_flag = False
     epoch = 0
@@ -125,12 +121,10 @@ def main():
             break
         else:
             epoch += 1
-            dl = initialize_memory_dataloader(
-                args.mem_dir, args.output_layer_index
-            )
+            dl = initialize_memory_dataloader(args.mem_dir, args.mem_layer)
     quantizer = trainer.get_quantizer()
     quantizer_fn = (
-        f"{args.output_layer_index}layer-"
+        f"{args.mem_layer}layer-"
         + quantizer.get_id()
         + f"-bytes_per_frame_{args.bytes_per_frame}-quantizer.pt"
     )
diff --git a/egs/librispeech/ASR/conformer_ctc/wav2vec_code_indices.py b/egs/librispeech/ASR/conformer_ctc/wav2vec_code_indices.py
new file mode 100755
index 000000000..4830d7a43
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc/wav2vec_code_indices.py
@@ -0,0 +1,311 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (author: Liyong Guo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import os
+from pathlib import Path
+from quantization import Quantizer
+
+import numpy as np
+import torch
+from asr_datamodule import LibriSpeechAsrDataModule
+from lhotse.features.io import NumpyHdf5Writer
+from lhotse import CutSet
+
+from icefall.env import get_env_info
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+)
+
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=34,
+        help="It specifies the checkpoint to use for decoding. "
+        "Note: Epoch counts from 0.",
+    )
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=1,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--data-dir",
+        type=Path,
+        default="./data/",
+        help="The data dir",
+    )
+
+    parser.add_argument(
+        "--mem-dir",
+        type=Path,
+        default="conformer_ctc/exp/mem",
+        help="The dir containing the trained quantizer",
+    )
+
+    parser.add_argument(
+        "--quantizer-id",
+        type=str,
+        default=None,
+        help="The id of the trained quantizer. "
+        "Set this manually in case of mistakes.",
+    )
+
+    parser.add_argument(
+        "--bytes-per-frame",
+        type=int,
+        default=4,
+        help="The number of bytes used to quantize each memory embedding",
+    )
+
+    parser.add_argument(
+        "--memory-embedding-dim",
+        type=int,
+        default=512,
+        help="Dim of the memory embeddings used to train the quantizer",
+    )
+
+    parser.add_argument(
+        "--subset",
+        type=str,
+        default=None,
+        help="Which subset to extract codebook indexes for: "
+        "clean-100, clean-360 or other-500",
+    )
+    parser.add_argument(
+        "--model-id",
+        type=str,
+        default="wav2vec",
+        help="A short string identifying the source model of the embeddings",
+    )
+
+    parser.add_argument(
+        "--mem-layer",
+        type=int,
+        default=None,
+        help="Which layer to extract memory embeddings from. "
+        "Set this manually in case of mistakes.",
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            "feature_dim": 80,
+            "nhead": 8,
+            "attention_dim": 512,
+            "subsampling_factor": 4,
+            "num_decoder_layers": 6,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
+            "output_beam": 10,
+            "use_double_scores": True,
+            "env_info": get_env_info(),
+        }
+    )
+    return params
+
+
+def compute_codeindices(
+    model: torch.nn.Module,
+    processor: Wav2Vec2Processor,
+    dl: torch.utils.data.DataLoader,
+    quantizer: Quantizer,
+    params: AttributeDict,
+    writer: NumpyHdf5Writer,
+) -> CutSet:
+    """Extract the framewise codebook indexes of a dataset.
+
+    Args:
+      model:
+        The neural network model.
+      processor:
+        The wav2vec 2.0 processor used to normalize the input waveforms.
+      dl:
+        Dataloader containing the dataset.
+      quantizer:
+        The trained quantizer used to encode the memory embeddings.
+      params:
+        Parameters for computing codebook indexes.
+      writer:
+        The writer used to store the codebook indexes.
+    Returns:
+      Return a CutSet whose cuts have a `codebook_indices` attribute
+      pointing to the stored codebook indexes.
+    """
+    num_cuts = 0
+
+    cuts = []
+    total_frames = 0
+    for batch_idx, batch in enumerate(dl):
+        inputs = processor(
+            batch["inputs"],
+            sampling_rate=16000,
+            return_tensors="pt",
+            padding="longest",
+        )
+        feature = inputs["input_values"].squeeze(0)
+        feature = feature.to(model.device)
+        B, T = feature.shape
+
+        supervisions = batch["supervisions"]
+        num_samples = supervisions["num_samples"]
+        mask = torch.arange(0, T).expand(B, T) < num_samples.reshape([-1, 1])
+        mask = mask.to(model.device)
+        encoder_memory = model.wav2vec2(feature, mask)[0]  # [N, T, C]
+
+        codebook_indices = quantizer.encode(encoder_memory)
+
+        # [N, T, num_codebooks]
+        codebook_indices = codebook_indices.to("cpu").numpy().astype(np.int16)
+
+        cut_list = supervisions["cut"]
+        assert len(cut_list) == codebook_indices.shape[0]
+
+        assert all(c.start == 0 for c in cut_list)
+        for idx, cut in enumerate(cut_list):
+            num_frames = supervisions["num_samples"][idx] // 320
+            cut.codebook_indices = writer.store_array(
+                key=cut.id,
+                value=codebook_indices[idx][:num_frames],
+                frame_shift=0.02,
+                temporal_dim=0,
+                start=0,
+            )
+            total_frames += num_frames
+
+        cuts += cut_list
+        num_cuts += len(cut_list)
+        logging.info(
+            f"processed {total_frames} frames and {num_cuts} cuts; "
+            f"batch {batch_idx}"
+        )
+    return CutSet.from_cuts(cuts)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+
+    assert args.subset in ["clean-100", "clean-360", "other-500"], args.subset
+    # disable augmentation when extracting codebook indexes
+    assert args.enable_augmentation is False
+
+    # Manually set options
+    assert args.quantizer_id is not None
+    assert args.model_id is not None
+    assert args.mem_layer is not None
+
+    assert args.return_cuts is True
+    assert args.concatenate_cuts is False
+
+    params = get_params()
+    params.update(vars(args))
+
+    setup_logger(f"{params.exp_dir}/log/codebook_index")
+
+    logging.info("Computing codebook indexes started")
+    logging.info(params)
+
+    logging.info("About to create model")
+    quantizer_fn = (
+        params.mem_dir
+        / f"{params.mem_layer}layer-{params.quantizer_id}-bytes_per_frame_{params.bytes_per_frame}-quantizer.pt"  # noqa: E501
+    )
+    assert os.path.isfile(quantizer_fn), f"{quantizer_fn} does not exist."
+    model = Wav2Vec2ForCTC.from_pretrained(
+        "facebook/wav2vec2-large-960h-lv60-self",
+        mem_layer=params.mem_layer,
+    ).to("cuda")
+    processor = Wav2Vec2Processor.from_pretrained(
+        "facebook/wav2vec2-large-960h-lv60-self"
+    )
+
+    quantizer = Quantizer(
+        dim=params.memory_embedding_dim,
+        num_codebooks=args.bytes_per_frame,
+        codebook_size=256,
+    )
+    quantizer.load_state_dict(torch.load(quantizer_fn))
+    quantizer = quantizer.to("cuda")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    params["device"] = device
+
+    model.to(device)
+    model.eval()
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    train_dl = librispeech.train_dataloaders()
+
+    cdidx_dir = (
+        Path(params.data_dir)
+        / f"{args.model_id}-{args.mem_layer}layer-{args.quantizer_id}-bytes_per_frame-{args.bytes_per_frame}"  # noqa: E501
+    )
+    cdidx_dir.mkdir(exist_ok=True)
+
+    with NumpyHdf5Writer(
+        cdidx_dir
+        / f"{args.model_id}-{args.mem_layer}layer-cdidx_train-{args.subset}"
+    ) as writer:
+        cut_set = compute_codeindices(
+            model=model,
+            processor=processor,
+            dl=train_dl,
+            quantizer=quantizer,
+            params=params,
+            writer=writer,
+        )
+        cut_set.to_json(cdidx_dir / f"cuts_train-{args.subset}.json.gz")
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/conformer_ctc/wav2vec_decode.py b/egs/librispeech/ASR/conformer_ctc/wav2vec_decode.py
new file mode 100755
index 000000000..1488780f6
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc/wav2vec_decode.py
@@ -0,0 +1,265 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+
+from icefall.env import get_env_info
+
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+    store_transcripts,
+    write_error_stats,
+)
+
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="ctc_greedy_search",
+        help="Decoding method.",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="The experiment dir",
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            # parameters for conformer
+            "subsampling_factor": 4,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
+            "feature_dim": 80,
+            "nhead": 8,
+            "attention_dim": 512,
+            "num_decoder_layers": 6,
+            # parameters for decoding
+            "search_beam": 20,
+            "output_beam": 8,
+            "min_active_states": 30,
+            "max_active_states": 10000,
+            "use_double_scores": True,
+            "env_info": get_env_info(),
+        }
+    )
+    return params
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    model: nn.Module,
+    processor: Wav2Vec2Processor,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      model:
+        The neural model.
+      processor:
+        The wav2vec 2.0 processor used to normalize the input waveforms
+        and to decode the predicted token ids.
+
+    Returns:
+      Return a dict whose key is the decoding method, e.g.
+      "ctc_greedy_search". Its value is a list of tuples.
+      Each tuple contains two elements: the first is the reference
+      transcript, and the second is the predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        supervisions = batch["supervisions"]
+        # MVN is applied by the processor
+        inputs = processor(
+            batch["inputs"],
+            sampling_rate=16000,
+            return_tensors="pt",
+            padding="longest",
+        )
+        feature = inputs["input_values"].squeeze(0)
+        B, T = feature.shape
+        num_samples = supervisions["num_samples"]
+        mask = torch.arange(0, T).expand(B, T) < num_samples.reshape([-1, 1])
+        mask = mask.to(model.device)
+        feature = feature.to(model.device)
+        memory_embeddings = model.wav2vec2(feature, mask)[0]
+        logits = model.lm_head(memory_embeddings)
+        predicted_ids = torch.argmax(logits, dim=-1)
+        hyps = processor.batch_decode(predicted_ids)
+
+        texts = batch["supervisions"]["text"]
+
+        this_batch = []
+        assert len(hyps) == len(texts)
+
+        for hyp_text, ref_text in zip(hyps, texts):
+            ref_words = ref_text.split()
+            hyp_words = hyp_text.split()
+            this_batch.append((ref_words, hyp_words))
+
+        results["ctc_greedy_search"].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % 20 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.exp_dir / f"wav2vec2-recogs-{test_set_name}-{key}.txt"
+        )
+        store_transcripts(filename=recog_path, texts=results)
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.exp_dir / f"wav2vec2-errs-{test_set_name}-{key}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info(
+            "Wrote detailed error stats to {}".format(errs_filename)
+        )
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = params.exp_dir / f"wav2vec2-wer-summary-{test_set_name}.txt"
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    # args.lang_dir = Path(args.lang_dir)
+    # args.lm_dir = Path(args.lm_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
+    logging.info("Decoding started")
+    logging.info(params)
+
+    # lexicon = Lexicon(params.lang_dir)
+    # max_token_id = max(lexicon.tokens)
+    # num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    model = Wav2Vec2ForCTC.from_pretrained(
+        "facebook/wav2vec2-large-960h-lv60-self"
+    ).to("cuda")
+    processor = Wav2Vec2Processor.from_pretrained(
+        "facebook/wav2vec2-large-960h-lv60-self"
+    )
+    model.to(device)
+    model.eval()
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    # CAUTION: `test_sets` is for displaying only.
+    # If you want to skip test-clean, you have to skip
+    # it inside the for loop. That is, use
+    #
+    #   if test_set == 'test-clean': continue
+    #
+    test_sets = ["test-clean", "test-other"]
+    for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            model=model,
+            processor=processor,
+        )
+
+        save_results(
+            params=params, test_set_name=test_set, results_dict=results_dict
+        )
+
+    logging.info("Done!")
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/conformer_ctc/wav2vec_memory_embedding.py b/egs/librispeech/ASR/conformer_ctc/wav2vec_memory_embedding.py
new file mode 100755
index 000000000..91de07c00
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc/wav2vec_memory_embedding.py
@@ -0,0 +1,251 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (author: Liyong Guo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+from asr_datamodule import LibriSpeechAsrDataModule
+from lhotse.features.io import NumpyHdf5Writer
+from lhotse import CutSet
+
+from icefall.env import get_env_info
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+)
+
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=34,
+        help="It specifies the checkpoint to use for decoding. "
+        "Note: Epoch counts from 0.",
+    )
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=1,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--mem-dir",
+        type=str,
+        default="conformer_ctc/exp/mem",
+        help="The dir to store the extracted memory embeddings",
+    )
+
+    parser.add_argument(
+        "--num-utts",
+        type=int,
+        default=1000,
+        help="Number of utterances to extract memory embeddings for",
+    )
+
+    parser.add_argument(
+        "--mem-layer",
+        type=int,
+        default=None,
+        help="Which layer to extract memory embeddings from. "
+        "See: https://github.com/glynpu/transformers/pull/1/files",
+    )
+
+    parser.add_argument(
+        "--pretrained_model",
+        type=Path,
+        default=None,
+        help="Use a pretrained model, e.g. "
+        "a model downloaded from the model zoo",
+    )
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            "feature_dim": 80,
+            "nhead": 8,
+            "attention_dim": 512,
+            "subsampling_factor": 4,
+            "num_decoder_layers": 6,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
+            "output_beam": 10,
+            "use_double_scores": True,
+            "env_info": get_env_info(),
+        }
+    )
+    return params
+
+
+def compute_memory(
+    model: torch.nn.Module,
+    processor: Wav2Vec2Processor,
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    writer: NumpyHdf5Writer,
+) -> CutSet:
+    """Compute the framewise memory embeddings of a dataset.
+
+    Args:
+      model:
+        The neural network model.
+      processor:
+        The wav2vec 2.0 processor used to normalize the input waveforms.
+      dl:
+        Dataloader containing the dataset.
+      params:
+        Parameters for computing memory.
+      writer:
+        The writer used to store the memory embeddings.
+    Returns:
+      Return a CutSet whose cuts have an `encoder_memory` attribute
+      pointing to the stored memory embeddings.
+    """
+
+    cuts = []
+    total_frames = 0
+    for batch_idx, batch in enumerate(dl):
+        inputs = processor(
+            batch["inputs"],
+            sampling_rate=16000,
+            return_tensors="pt",
+            padding="longest",
+        )
+        feature = inputs["input_values"].squeeze(0)
+        feature = feature.to(model.device)
+        B, T = feature.shape
+
+        supervisions = batch["supervisions"]
+        num_samples = supervisions["num_samples"]
+        mask = torch.arange(0, T).expand(B, T) < num_samples.reshape([-1, 1])
+        mask = mask.to(model.device)
+        memory_embeddings = model.wav2vec2(feature, mask)[0]  # [N, T, C]
+
+        encoder_memory = memory_embeddings.to("cpu").numpy()
+
+        cut_list = supervisions["cut"]
+        assert len(cut_list) == encoder_memory.shape[0]
+        assert all(c.start == 0 for c in cut_list)
+
+        for idx, cut in enumerate(cut_list):
+            num_frames = supervisions["num_samples"][idx] // 320
+            cut.encoder_memory = writer.store_array(
+                key=cut.id,
+                value=encoder_memory[idx][:num_frames],
+            )
+            total_frames += num_frames
+
+        cuts += cut_list
+        logging.info(f"Processed {len(cuts)} cuts")
+        if len(cuts) > params.num_utts:
+            break
+    return CutSet.from_cuts(cuts)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    assert args.mem_layer is not None
+    assert args.mem_layer > 0 and args.mem_layer < 24
+
+    assert args.return_cuts is True
+    assert args.concatenate_cuts is False
+
+    params = get_params()
+    params.update(vars(args))
+
+    setup_logger(f"{params.exp_dir}/log/mem")
+
+    logging.info("Computing memory embeddings started")
+    logging.info(params)
+
+    logging.info("About to create model")
+
+    model = Wav2Vec2ForCTC.from_pretrained(
+        "facebook/wav2vec2-large-960h-lv60-self",
+        output_layer_index=params.mem_layer,
+    ).to("cuda")
+    processor = Wav2Vec2Processor.from_pretrained(
+        "facebook/wav2vec2-large-960h-lv60-self"
+    )
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    params["device"] = device
+
+    model.to(device)
+    model.eval()
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    test_dl = librispeech.test_dataloaders()  # a list
+
+    mem_dir = Path(params.mem_dir)
+    mem_dir.mkdir(exist_ok=True)
+
+    enabled_datasets = {
+        "test_clean": test_dl[0],
+    }
+
+    with NumpyHdf5Writer(
+        mem_dir / f"{args.mem_layer}layer-memory_embeddings"
+    ) as writer:
+        for name, dl in enabled_datasets.items():
+            cut_set = compute_memory(
+                model=model,
+                processor=processor,
+                dl=dl,
+                params=params,
+                writer=writer,
+            )
+            cut_set.to_json(
+                mem_dir / f"{args.mem_layer}layer-memory_manifest.json.gz"
+            )
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()