#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
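"""
This script extracts framewise codebook indices of memory (encoder)
embeddings from a pretrained Conformer model using a trained quantizer,
and stores them together with the corresponding cuts.

Example invocation (the values below are illustrative only):

    python3 ./<path-to-this-script> \
        --pretrained-model path/to/pretrained.pt \
        --quantizer-id some_id \
        --mem-layer 6 \
        --model-id icefall \
        --bytes-per-frame 4 \
        --subset clean-100
"""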
import argparse
import logging
from pathlib import Path

import torch
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse import CutSet
from lhotse.features.io import NumpyHdf5Writer
from quantization import Quantizer

from icefall.checkpoint import load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, setup_logger


def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=34,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_500",
help="The lang dir",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="The experiment dir",
)
parser.add_argument(
"--data-dir",
type=Path,
default="./data/",
help="The experiment dir",
)
parser.add_argument(
"--mem-dir",
type=Path,
default="conformer_ctc/exp/mem",
help="The experiment dir",
)
parser.add_argument(
"--quantizer-id",
type=str,
default=None,
help="quantizer_id",
)
parser.add_argument(
"--bytes-per-frame",
type=int,
default=4,
help="The number of bytes to use to quantize each memory embeddings",
)
parser.add_argument(
"--memory-embedding-dim",
type=int,
default=512,
help="dim of memory embeddings to train quantizer",
)
parser.add_argument(
"--pretrained-model",
type=Path,
default=None,
help="use a pretrained model, e.g. a modle downloaded from model zoo",
)
parser.add_argument(
"--model-id",
type=str,
default=None,
help="a short str to introduce which models the embeddings come from"
"e.g. icefall or wav2vec2",
)
parser.add_argument(
"--mem-layer",
type=int,
default=None,
help="which layer to extract memory embedding"
"Set this manully to avoid mistake.",
)
    return parser


def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 80,
"nhead": 8,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
"vgg_frontend": False,
"use_feat_batchnorm": True,
"output_beam": 10,
"use_double_scores": True,
"env_info": get_env_info(),
}
)
    return params


def compute_codeindices(
    model: torch.nn.Module,
    dl: torch.utils.data.DataLoader,
    quantizer: Quantizer,
    params: AttributeDict,
    writer: NumpyHdf5Writer,
) -> CutSet:
"""Compute the framewise alignments of a dataset.
Args:
model:
The neural network model.
dl:
Dataloader containing the dataset.
params:
Parameters for computing memory.
Returns:
Return a list of tuples. Each tuple contains two entries:
- Utterance ID
- memory embeddings
"""
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
num_cuts = 0
device = params.device
cuts = []
total_frames = 0
for batch_idx, batch in enumerate(dl):
feature = batch["inputs"]
# at entry, feature is [N, T, C]
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
_, encoder_memory, memory_mask = model(feature, supervisions)
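        # Quantize every frame into bytes_per_frame codebook indices;
        # with as_bytes=True each index is packed into a single byte.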
codebook_indices = quantizer.encode(encoder_memory, as_bytes=True)
# [T, N, C] --> [N, T, C]
codebook_indices = codebook_indices.transpose(0, 1).to("cpu").numpy()
cut_list = supervisions["cut"]
assert len(cut_list) == codebook_indices.shape[0]
num_cuts += len(cut_list)
assert all(supervisions["start_frame"] == 0)
for idx, cut in enumerate(cut_list):
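            # Number of valid output frames for this cut; this mirrors the
            # model's ~4x frame subsampling (two stride-2 convolutions in
            # the subsampling frontend).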
num_frames = (
((supervisions["num_frames"][idx] - 3) // 2 + 1) - 3
) // 2 + 1
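            # frame_shift below assumes a 10 ms input frame shift:
            # 0.01 s * subsampling factor 4 = 0.04 s per output frame.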
cut.codebook_indices = writer.store_array(
key=cut.id,
value=codebook_indices[idx][:num_frames],
frame_shift=0.04,
temporal_dim=0,
start=0,
)
total_frames += num_frames
cuts += cut_list
logging.info(
f"processed {total_frames} frames and {num_cuts} cuts; {batch_idx} of {num_batches}" # noqa: E501
)
    return CutSet.from_cuts(cuts)


@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
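    # Sanity checks: codebook indices are attached per cut, so cuts must be
    # returned and must not be concatenated; the remaining options have no
    # usable defaults and must be given explicitly.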
assert args.return_cuts is True
assert args.concatenate_cuts is False
assert args.quantizer_id is not None
assert args.model_id is not None
assert args.mem_layer is not None
assert args.pretrained_model is not None
assert args.subset in ["clean-100", "clean-360", "other-500"]
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/mem")
logging.info("Computing memory embedings- started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
logging.info("About to create model")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
quantizer_fn = (
params.mem_dir
/ f"{params.mem_layer}layer-{params.quantizer_id}-bytes_per_frame_{params.bytes_per_frame}-quantizer.pt" # noqa: E501
)
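    # Each codebook has 256 entries, so one codebook index fits in exactly
    # one byte; hence num_codebooks equals bytes_per_frame.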
quantizer = Quantizer(
dim=params.memory_embedding_dim,
num_codebooks=args.bytes_per_frame,
codebook_size=256,
)
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    params["device"] = device

    quantizer.load_state_dict(torch.load(quantizer_fn, map_location="cpu"))
    quantizer = quantizer.to(device)

    load_checkpoint(f"{params.pretrained_model}", model)
    model.to(device)
    model.eval()
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
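    # The output dir name encodes the model, layer, quantizer and
    # bytes-per-frame so that different extraction runs do not collide.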
cdidx_dir = (
Path(params.data_dir)
/ f"{args.model_id}-{args.mem_layer}layer-{args.quantizer_id}-bytes_per_frame-{args.bytes_per_frame}" # noqa: E501
)
    cdidx_dir.mkdir(parents=True, exist_ok=True)
with NumpyHdf5Writer(
cdidx_dir
/ f"{args.model_id}-{args.mem_layer}layer-cdidx_train-{args.subset}"
) as writer:
cut_set = compute_codeindices(
model=model,
dl=train_dl,
quantizer=quantizer,
params=params,
writer=writer,
)
    cut_set.to_json(cdidx_dir / f"cuts_train-{args.subset}.json.gz")


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
main()