mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
314 lines
8.4 KiB
Python
Executable File
314 lines
8.4 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Copyright 2021 Xiaomi Corp. (author: Liyong Guo)
|
|
#
|
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import argparse
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import List, Tuple
|
|
from quantization import Quantizer
|
|
|
|
import torch
|
|
from asr_datamodule import LibriSpeechAsrDataModule
|
|
from conformer import Conformer
|
|
from lhotse.features.io import NumpyHdf5Writer
|
|
from lhotse import CutSet
|
|
|
|
from icefall.checkpoint import load_checkpoint
|
|
from icefall.env import get_env_info
|
|
from icefall.lexicon import Lexicon
|
|
from icefall.utils import (
|
|
AttributeDict,
|
|
setup_logger,
|
|
)
|
|
|
|
|
|
def get_parser():
    """Build the command-line parser for this script.

    Only the options specific to codebook-index extraction are defined
    here; main() adds the data-module options on top of the returned parser.

    Returns:
      An argparse.ArgumentParser whose help output shows default values.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=34,
        help="It specifies the checkpoint to use for decoding. "
        "Note: Epoch counts from 0.",
    )
    parser.add_argument(
        "--avg",
        type=int,
        default=1,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_bpe_500",
        help="The lang dir",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="conformer_ctc/exp",
        help="The experiment dir",
    )

    # argparse applies `type` to string defaults, so this parses to a Path.
    parser.add_argument(
        "--data-dir",
        type=Path,
        default="./data/",
        help="The dir under which the codebook indices and the cut "
        "manifest are written",
    )

    parser.add_argument(
        "--mem-dir",
        type=Path,
        default="conformer_ctc/exp/mem",
        help="The dir containing the trained quantizer checkpoint",
    )

    parser.add_argument(
        "--quantizer-id",
        type=str,
        default=None,
        help="ID of the trained quantizer; it is part of the quantizer "
        "checkpoint filename and of the output dir name",
    )

    parser.add_argument(
        "--bytes-per-frame",
        type=int,
        default=4,
        help="The number of bytes to use to quantize each memory embedding",
    )

    parser.add_argument(
        "--memory-embedding-dim",
        type=int,
        default=512,
        help="dim of memory embeddings to train quantizer",
    )

    parser.add_argument(
        "--pretrained-model",
        type=Path,
        default=None,
        help="use a pretrained model, e.g. a model downloaded from model zoo",
    )

    parser.add_argument(
        "--model-id",
        type=str,
        default=None,
        help="a short str to introduce which models the embeddings come "
        "from, e.g. icefall or wav2vec2",
    )

    parser.add_argument(
        "--mem-layer",
        type=int,
        default=None,
        help="which layer to extract memory embedding from. "
        "Set this manually to avoid mistakes.",
    )

    return parser
|
|
|
|
|
|
def get_params() -> AttributeDict:
    """Return the fixed settings this script does not expose on the CLI.

    The result holds the model hyper-parameters that main() passes to the
    Conformer constructor, two search-related values, and a snapshot of the
    environment taken by get_env_info().
    """
    model_settings = {
        "feature_dim": 80,
        "nhead": 8,
        "attention_dim": 512,
        "subsampling_factor": 4,
        "num_decoder_layers": 6,
        "vgg_frontend": False,
        "use_feat_batchnorm": True,
    }
    search_settings = {
        "output_beam": 10,
        "use_double_scores": True,
    }
    # Merge in the original key order so that logging the params looks the
    # same as before.
    return AttributeDict(
        {
            **model_settings,
            **search_settings,
            "env_info": get_env_info(),
        }
    )
|
|
|
|
|
|
def compute_codeindices(
    model: torch.nn.Module,
    dl: torch.utils.data.DataLoader,
    quantizer: Quantizer,
    params: AttributeDict,
    writer: NumpyHdf5Writer,
) -> CutSet:
    """Quantize the encoder memory of each cut and store the code indices.

    For every batch the model is run, its second output (the encoder
    memory) is encoded by the quantizer, and the resulting per-cut arrays
    are written through `writer`. Each cut gets a `codebook_indices`
    attribute holding whatever `writer.store_array` returned.

    Args:
      model:
        The neural network model. Called as `model(feature, supervisions)`
        and expected to return a 3-tuple whose second item is the memory.
      dl:
        Dataloader containing the dataset. Each batch must provide
        "inputs" and "supervisions"; the supervisions must carry "cut",
        "num_frames" and "start_frame".
      quantizer:
        The trained quantizer used to encode the encoder memory.
      params:
        Parameters for computing memory. Only `params.device` is read.
      writer:
        Storage writer that receives one array per cut.
    Returns:
      A CutSet built from all processed cuts.
    """
    try:
        num_batches = len(dl)
    except TypeError:
        # Some dataloaders have no length; only the progress log needs it.
        num_batches = "?"

    device = params.device
    cuts = []
    num_cuts = 0
    total_frames = 0
    for batch_idx, batch in enumerate(dl):
        feature = batch["inputs"]

        # at entry, feature is [N, T, C]
        assert feature.ndim == 3
        feature = feature.to(device)

        supervisions = batch["supervisions"]

        # The first and third model outputs are not needed here.
        _, encoder_memory, _ = model(feature, supervisions)
        codebook_indices = quantizer.encode(encoder_memory, as_bytes=True)

        # [T, N, C] --> [N, T, C]
        codebook_indices = codebook_indices.transpose(0, 1).to("cpu").numpy()

        cut_list = supervisions["cut"]
        assert len(cut_list) == codebook_indices.shape[0]
        num_cuts += len(cut_list)
        # Stored arrays are written with start=0, so every supervision has
        # to begin at the first frame of its cut.
        assert all(supervisions["start_frame"] == 0)
        for idx, cut in enumerate(cut_list):
            # Convert to a plain int first: the entry is presumably a
            # tensor element (TODO confirm against the dataset), and
            # keeping it as one would make `total_frames` a tensor and log
            # it as "tensor(...)".
            feature_frames = int(supervisions["num_frames"][idx])
            # Apply the (n - 3) // 2 + 1 length reduction twice; this
            # presumably mirrors the model's 4x subsampling -- verify
            # against the encoder front-end.
            num_frames = ((feature_frames - 3) // 2 + 1 - 3) // 2 + 1
            cut.codebook_indices = writer.store_array(
                key=cut.id,
                # Drop the trailing frames that exist only due to padding.
                value=codebook_indices[idx][:num_frames],
                frame_shift=0.04,
                temporal_dim=0,
                start=0,
            )
            total_frames += num_frames

        cuts += cut_list
        logging.info(
            f"processed {total_frames} frames and {num_cuts} cuts; {batch_idx} of {num_batches}"  # noqa: E501
        )
    return CutSet.from_cuts(cuts)
|
|
|
|
|
|
@torch.no_grad()
def main():
    """Extract quantized codebook indices for one training subset.

    Steps: parse the arguments, build the Conformer and load the pretrained
    checkpoint, load the trained quantizer, run both over the training
    dataloader, and write the codebook indices plus a cut manifest under
    `--data-dir`. The whole function runs with gradients disabled.
    """
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    # return_cuts / concatenate_cuts / subset are not defined by
    # get_parser(); presumably they come from the data module's arguments.
    assert args.return_cuts is True
    assert args.concatenate_cuts is False
    assert args.quantizer_id is not None
    assert args.model_id is not None
    assert args.mem_layer is not None
    assert args.pretrained_model is not None
    assert args.subset in ["clean-100", "clean-360", "other-500"]

    params = get_params()
    params.update(vars(args))

    setup_logger(f"{params.exp_dir}/log/mem")

    logging.info("Computing memory embedings- started")
    logging.info(params)

    lexicon = Lexicon(params.lang_dir)
    max_token_id = max(lexicon.tokens)
    num_classes = max_token_id + 1  # +1 for the blank

    logging.info("About to create model")
    model = Conformer(
        num_features=params.feature_dim,
        nhead=params.nhead,
        d_model=params.attention_dim,
        num_classes=num_classes,
        subsampling_factor=params.subsampling_factor,
        num_decoder_layers=params.num_decoder_layers,
        vgg_frontend=params.vgg_frontend,
        use_feat_batchnorm=params.use_feat_batchnorm,
    )

    # Pick the device once, before anything is moved, so that the model and
    # the quantizer always end up on the same device (and CPU-only hosts
    # work instead of failing on a hard-coded "cuda").
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    params["device"] = device

    quantizer_fn = (
        params.mem_dir
        / f"{params.mem_layer}layer-{params.quantizer_id}-bytes_per_frame_{params.bytes_per_frame}-quantizer.pt"  # noqa: E501
    )

    quantizer = Quantizer(
        dim=params.memory_embedding_dim,
        num_codebooks=args.bytes_per_frame,
        codebook_size=256,
    )
    # Load onto the CPU first so a checkpoint saved from a GPU can still be
    # read on a machine without one; it is moved to `device` right after.
    quantizer.load_state_dict(torch.load(quantizer_fn, map_location="cpu"))
    quantizer = quantizer.to(device)

    load_checkpoint(f"{params.pretrained_model}", model)

    model.to(device)
    model.eval()

    librispeech = LibriSpeechAsrDataModule(args)

    train_dl = librispeech.train_dataloaders()

    cdidx_dir = (
        Path(params.data_dir)
        / f"{args.model_id}-{args.mem_layer}layer-{args.quantizer_id}-bytes_per_frame-{args.bytes_per_frame}"  # noqa: E501
    )
    # parents=True: do not fail when --data-dir itself does not exist yet.
    cdidx_dir.mkdir(parents=True, exist_ok=True)

    with NumpyHdf5Writer(
        cdidx_dir
        / f"{args.model_id}-{args.mem_layer}layer-cdidx_train-{args.subset}"
    ) as writer:
        cut_set = compute_codeindices(
            model=model,
            dl=train_dl,
            quantizer=quantizer,
            params=params,
            writer=writer,
        )
        cut_set.to_json(cdidx_dir / f"cuts_train-{args.subset}.json.gz")
|
|
|
|
|
|
# Restrict PyTorch to one inter-op and one intra-op thread. Both calls run
# at import time, i.e. before main() is invoked below.
torch.set_num_interop_threads(1)
torch.set_num_threads(1)

# Script entry point: only run the extraction when executed directly.
if __name__ == "__main__":
    main()
|