#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from typing import Dict, List, Tuple

import torch
from fairseq import checkpoint_utils, tasks, utils
from fairseq.data.data_utils import post_process
from omegaconf import OmegaConf

from icefall.utils import AttributeDict


def _load_hubert_model(params: AttributeDict):
"""
Load the hubert model.
The model loaded is specified by params.hubert_model_dir
and params.teacher_model_id.
Returned model carries hubert,
while processor is responsible to map model's output to human readable transcripts.
"""
cfg_task = OmegaConf.create(
{
"_name": "hubert_pretraining",
"single_target": True,
"fine_tuning": True,
"data": str(params.hubert_model_dir),
}
)
model_path = Path(params.hubert_model_dir) / (params.teacher_model_id + ".pt")
task = tasks.setup_task(cfg_task)
processor = task.target_dictionary
models, saved_cfg = checkpoint_utils.load_model_ensemble(
utils.split_paths(str(model_path), separator="\\"),
arg_overrides={},
strict=True,
suffix="",
num_shards=1,
)
model = models[0]
model.to(params.device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
return model, processor
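
# A minimal usage sketch of _load_hubert_model (illustration only; the concrete
# directory, model id and device below are assumed values, not ones required by
# this file):
#
#     params = AttributeDict(
#         {
#             "hubert_model_dir": Path("./hubert_models/"),
#             "teacher_model_id": "hubert_xtralarge_ll60k_finetune_ls960",
#             "device": torch.device("cpu"),
#         }
#     )
#     model, processor = _load_hubert_model(params)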


class HubertXlargeFineTuned:
    """
    A wrapper of the hubert extra large fine-tuned model.

    A teacher model is responsible for:
    1. loading the teacher model;
    2. extracting embeddings to train the quantizer;
    3. extracting codebook indexes;
    4. verifying its performance with the ctc_greedy_search method.
    """

    def __init__(self, params: AttributeDict):
self.model, self.processor = _load_hubert_model(params)
self.w2v_model = self.model.w2v_encoder.w2v_model
self.params = params

    @staticmethod
    def get_params() -> AttributeDict:
        """Return a dict containing parameters defined in other modules.

        Their default values conflict with hubert's requirements,
        so they are reset as follows.
        """
params = AttributeDict(
{
# parameters defined in asr_datamodule.py
"input_strategy": "AudioSamples",
"enable_musan": False,
"enable_spec_aug": False,
"return_cuts": True,
"drop_last": False,
# parameters used by quantizer
"embedding_dim": 1280,
}
)
return params

    @classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
# Options about model loading.
parser.add_argument(
"--hubert-model-dir",
type=Path,
default="./pruned_transducer_stateless6/exp/hubert_models/",
help="path to save downloaded hubert models.",
)
parser.add_argument(
"--teacher-model-id",
type=str,
default="hubert_xtralarge_ll60k_finetune_ls960",
help="""could be one of:
[
"hubert_xtralarge_ll60k_finetune_ls960", # fine-tuned model.
"hubert_xtralarge_ll60k.pt", # pretrained model without fintuing.
]""",
)
parser.add_argument(
"--total-layers",
type=int,
default=48,
)

    # Modified from HubertModel.forward to extract the output of all
    # intermediate layers.
def extract_layers_result(
self,
batch: Dict,
) -> List[torch.Tensor]:
"""
Extract activations from all layers.
"""
features = batch["inputs"]
        # corresponds to task.normalize in fairseq
features = torch.nn.functional.layer_norm(features, features.shape)
supervisions = batch["supervisions"]
num_samples = supervisions["num_samples"]
B, T = features.shape
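        # Build a mask that is True for the padded tail of each row,
        # i.e. positions past that cut's valid samples.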
padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape([-1, 1])
padding_mask = padding_mask.to(self.params.device)
features = features.to(self.params.device)
features = self.w2v_model.forward_features(features)
features = features.transpose(1, 2)
features = self.w2v_model.layer_norm(features)
padding_mask = self.w2v_model.forward_padding_mask(features, padding_mask)
if self.w2v_model.post_extract_proj is not None:
features = self.w2v_model.post_extract_proj(features)
_, layer_results = self.w2v_model.encoder(
features,
padding_mask=padding_mask,
)
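        # Each element of layer_results is a tuple whose first item is the
        # corresponding layer's activation with shape (T, B, C); the callers
        # below (extract_embedding, ctc_greedy_search) index it with [0]
        # and transpose it to (N, T, C).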
return layer_results

    def extract_embedding(self, batch) -> Tuple[torch.Tensor, List[int]]:
        """
        Extract embeddings from the layer given by self.params.embedding_layer.

        These embeddings can be used to train the quantizer
        or to extract codebook indexes.

        The returned List[int] contains the valid length of each embedding.
        We only want to store codebook indexes related to
        these valid embeddings.
        """
supervisions = batch["supervisions"]
cut_list = supervisions["cut"]
assert all(c.start == 0 for c in cut_list)
layer_results = self.extract_layers_result(batch)
embeddings = layer_results[self.params.embedding_layer - 1][0]
encoder_embedding = embeddings.transpose(0, 1) # N, T, C
N = encoder_embedding.shape[0]
assert len(cut_list) == N
        # 320 comes from: 16,000 / 50 = sample_rate / hubert output frame rate.
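        # e.g. a 4-second cut at 16 kHz has 64,000 samples,
        # giving 64,000 // 320 = 200 embedding frames.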
num_frames = (supervisions["num_samples"] // 320).tolist()
return encoder_embedding, num_frames

    def ctc_greedy_search(self, batch):
        """
        Mainly used to verify that the hubert model is used correctly.
        """
layer_results = self.extract_layers_result(batch=batch)
encoder_out = self.w2v_model.encoder.layer_norm(
layer_results[self.params.total_layers - 1][0]
)
encoder_out = self.model.w2v_encoder.proj(encoder_out.transpose(0, 1))
toks = encoder_out.argmax(dim=-1)
blank = 0
toks = [tok.unique_consecutive() for tok in toks]
hyps = [self.processor.string(tok[tok != blank].int().cpu()) for tok in toks]
hyps = [post_process(hyp, "letter") for hyp in hyps]
return hyps