use wav2vec as a teacher model

Guo Liyong 2021-12-23 19:08:35 +08:00
parent 3b42f0347f
commit a9ad9553b5
4 changed files with 835 additions and 14 deletions


@@ -59,7 +59,7 @@ def get_parser():
     )
     parser.add_argument(
-        "--output-layer-index",
+        "--mem-layer",
         type=int,
         default=None,
         help="which layer to extract memory embedding"
@@ -69,14 +69,10 @@ def get_parser():
     return parser


-def initialize_memory_dataloader(
-    mem_dir: Path = None, output_layer_index: int = None
-):
+def initialize_memory_dataloader(mem_dir: Path = None, mem_layer: int = None):
     assert mem_dir is not None
-    assert output_layer_index is not None
-    mem_manifest_file = (
-        mem_dir / f"{output_layer_index}layer-memory_manifest.json"
-    )
+    assert mem_layer is not None
+    mem_manifest_file = mem_dir / f"{mem_layer}layer-memory_manifest.json"
     assert os.path.isfile(
         mem_manifest_file
     ), f"{mem_manifest_file} does not exist."
@@ -95,14 +91,14 @@ def initialize_memory_dataloader(
 def main():
     parser = get_parser()
     args = parser.parse_args()
-    assert args.output_layer_index is not None
+    assert args.mem_layer is not None
     setup_logger(f"{args.mem_dir}/log/quantizer_train")
     trainer = quantization.QuantizerTrainer(
         dim=args.memory_embedding_dim,
         bytes_per_frame=args.bytes_per_frame,
         device=torch.device("cuda"),
     )
-    dl = initialize_memory_dataloader(args.mem_dir, args.output_layer_index)
+    dl = initialize_memory_dataloader(args.mem_dir, args.mem_layer)
     num_cuts = 0
     done_flag = False
     epoch = 0
@@ -125,12 +121,10 @@ def main():
             break
         else:
             epoch += 1
-            dl = initialize_memory_dataloader(
-                args.mem_dir, args.output_layer_index
-            )
+            dl = initialize_memory_dataloader(args.mem_dir, args.mem_layer)
     quantizer = trainer.get_quantizer()
     quantizer_fn = (
-        f"{args.output_layer_index}layer-"
+        f"{args.mem_layer}layer-"
         + quantizer.get_id()
         + f"-bytes_per_frame_{args.bytes_per_frame}-quantizer.pt"
     )


@@ -0,0 +1,311 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from pathlib import Path
from typing import List, Tuple
from quantization import Quantizer
import numpy as np
import torch
from asr_datamodule import LibriSpeechAsrDataModule
from lhotse.features.io import NumpyHdf5Writer
from lhotse import CutSet
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
setup_logger,
)
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=34,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_500",
help="The lang dir",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="The experiment dir",
)
parser.add_argument(
"--data-dir",
type=Path,
default="./data/",
help="The experiment dir",
)
parser.add_argument(
"--mem-dir",
type=Path,
default="conformer_ctc/exp/mem",
help="The experiment dir",
)
parser.add_argument(
"--quantizer-id",
type=str,
default=None,
help="quantizer_id" "Manully set this incase of mistake.",
)
parser.add_argument(
"--bytes-per-frame",
type=int,
default=4,
help="The number of bytes to use to quantize each memory embeddings",
)
parser.add_argument(
"--memory-embedding-dim",
type=int,
default=512,
help="dim of memory embeddings to train quantizer",
)
parser.add_argument(
"--subset",
type=str,
default=None,
help="which subset to extract codebook index"
"clean-100, clean-360, other-500",
)
parser.add_argument(
"--model-id",
type=str,
default="wav2vec",
help="a short str to introduce which models the embeddings come from",
)
parser.add_argument(
"--mem-layer",
type=int,
default=None,
help="which layer to extract memory embedding"
"Set this manully incase of mistake.",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 80,
"nhead": 8,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
"vgg_frontend": False,
"use_feat_batchnorm": True,
"output_beam": 10,
"use_double_scores": True,
"env_info": get_env_info(),
}
)
return params
def compute_codeindices(
model: torch.nn.Module,
processor: Wav2Vec2Processor,
dl: torch.utils.data.DataLoader,
quantizer: Quantizer,
params: AttributeDict,
writer: NumpyHdf5Writer,
) -> CutSet:
"""Compute the framewise codebook indices of a dataset.
Args:
model:
The neural network model.
processor:
The wav2vec2 processor that normalizes and pads raw audio.
dl:
Dataloader containing the dataset.
quantizer:
The trained quantizer used to encode memory embeddings.
params:
Parameters for computing codebook indices.
writer:
The writer used to store the codebook indices on disk.
Returns:
Return a CutSet whose cuts carry references to the stored
codebook indices.
"""
num_cuts = 0
cuts = []
total_frames = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
for batch_idx, batch in enumerate(dl):
inputs = processor(
batch["inputs"],
sampling_rate=16000,
return_tensors="pt",
padding="longest",
)
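# The processor zero-mean/unit-variance normalizes each waveform and pads
# the batch to the longest utterance; squeeze(0) drops the extra leading
# dim it adds around the already-batched input.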
feature = inputs["input_values"].squeeze(0)
feature = feature.to(model.device)
B, T = feature.shape
supervisions = batch["supervisions"]
num_samples = supervisions["num_samples"]
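# Build a padding mask over raw samples: True for real audio, False for
# padded positions, so wav2vec2 ignores the padding.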
mask = torch.arange(0, T).expand(B, T) < num_samples.reshape([-1, 1])
mask = mask.to(model.device)
encoder_memory = model.wav2vec2(feature, mask)[0] # [N, T, C]
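# quantizer.encode() maps each frame's embedding to `bytes_per_frame`
# codebook indices; with codebook_size=256 every index fits in one byte.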
codebook_indices = quantizer.encode(encoder_memory)
# [N, T, num_codebooks]
codebook_indices = codebook_indices.to("cpu").numpy().astype(np.int16)
cut_list = supervisions["cut"]
assert len(cut_list) == codebook_indices.shape[0]
assert all(c.start == 0 for c in supervisions["cut"])
for idx, cut in enumerate(cut_list):
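# wav2vec2's conv front end downsamples the waveform by a factor of 320,
# i.e. one frame per 20 ms at 16 kHz, matching frame_shift=0.02 below.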
num_frames = supervisions["num_samples"][idx] // 320
cut.codebook_indices = writer.store_array(
key=cut.id,
value=codebook_indices[idx][:num_frames],
frame_shift=0.02,
temporal_dim=0,
start=0,
)
total_frames += num_frames
cuts += cut_list
num_cuts += len(cut_list)
logging.info(
f"processed {total_frames} frames and {num_cuts} cuts; "
f"batch {batch_idx} of {num_batches}"
)
return CutSet.from_cuts(cuts)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
assert args.subset in ["clean-100", "clean-360", "other-500"], args.subset
# disable augmentation when extracting codebook indices
assert args.enable_augmentation is False
# Manually set options
assert args.quantizer_id is not None
assert args.model_id is not None
assert args.mem_layer is not None
assert args.return_cuts is True
assert args.concatenate_cuts is False
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/codebook_index")
logging.info("Computing memory embedings started")
logging.info(params)
logging.info("About to create model")
quantizer_fn = (
params.mem_dir
/ f"{params.mem_layer}layer-{params.quantizer_id}-bytes_per_frame_{params.bytes_per_frame}-quantizer.pt" # noqa: E501
)
assert os.path.isfile(quantizer_fn), f"{quantizer_fn} does not exist."
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
params["device"] = device
# `mem_layer` is assumed to be supported by the patched transformers from
# https://github.com/glynpu/transformers/pull/1/files
model = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-large-960h-lv60-self",
mem_layer=params.mem_layer,
).to(device)
processor = Wav2Vec2Processor.from_pretrained(
"facebook/wav2vec2-large-960h-lv60-self"
)
quantizer = Quantizer(
dim=params.memory_embedding_dim,
# one codebook per byte: bytes_per_frame codebooks of size 256
num_codebooks=args.bytes_per_frame,
codebook_size=256,
)
quantizer.load_state_dict(torch.load(quantizer_fn, map_location="cpu"))
quantizer = quantizer.to(device)
model.eval()
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
cdidx_dir = (
Path(params.data_dir)
/ f"{args.model_id}-{args.mem_layer}layer-{args.quantizer_id}-bytes_per_frame-{args.bytes_per_frame}" # noqa: E501
)
cdidx_dir.mkdir(exist_ok=True)
with NumpyHdf5Writer(
cdidx_dir
/ f"{args.model_id}-{args.mem_layer}layer-cdidx_train-{args.subset}"
) as writer:
cut_set = compute_codeindices(
model=model,
processor=processor,
dl=train_dl,
quantizer=quantizer,
params=params,
writer=writer,
)
cut_set.to_json(cdidx_dir / f"cuts_train-{args.subset}.json.gz")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()


@@ -0,0 +1,265 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--method",
type=str,
default="ctc_greedy_search",
help="Decoding method.",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="The experiment dir",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"subsampling_factor": 4,
"vgg_frontend": False,
"use_feat_batchnorm": True,
"feature_dim": 80,
"nhead": 8,
"attention_dim": 512,
"num_decoder_layers": 6,
# parameters for decoding
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
"env_info": get_env_info(),
}
)
return params
def decode_dataset(
dl: torch.utils.data.DataLoader,
model: nn.Module,
processor,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
model:
The neural model.
processor:
The wav2vec2 processor used to normalize inputs and decode
predicted ids into text.
Returns:
Return a dict whose key is the decoding method; here it is
"ctc_greedy_search". Its value is a list of tuples. Each tuple
contains two elements: the first is the reference transcript and
the second is the predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
supervisions = batch["supervisions"]
# MVN: the processor applies per-utterance mean/variance normalization
inputs = processor(
batch["inputs"],
sampling_rate=16000,
return_tensors="pt",
padding="longest",
)
feature = inputs["input_values"].squeeze(0)
B, T = feature.shape
num_samples = supervisions["num_samples"]
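# Padding mask over raw samples: True for real audio, False for padding.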
mask = torch.arange(0, T).expand(B, T) < num_samples.reshape([-1, 1])
mask = mask.to(model.device)
feature = feature.to(model.device)
memory_embeddings = model.wav2vec2(feature, mask)[0]
logits = model.lm_head(memory_embeddings)
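# Greedy CTC decoding: take the argmax token at every frame; batch_decode()
# collapses repeated tokens and removes CTC blanks to produce text.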
predicted_ids = torch.argmax(logits, dim=-1)
hyps = processor.batch_decode(predicted_ids)
texts = batch["supervisions"]["text"]
this_batch = []
assert len(hyps) == len(texts)
for hyp_text, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
hyp_words = hyp_text.split()
this_batch.append((ref_words, hyp_words))
results["ctc_greedy_search"].extend(this_batch)
num_cuts += len(texts)
if batch_idx % 20 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.exp_dir / f"wav2vec2-recogs-{test_set_name}-{key}.txt"
)
store_transcripts(filename=recog_path, texts=results)
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.exp_dir / f"wav2vec2-errs-{test_set_name}-{key}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info(
"Wrote detailed error stats to {}".format(errs_filename)
)
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"wav2vec2-wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
# args.lang_dir = Path(args.lang_dir)
# args.lm_dir = Path(args.lm_dir)
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
logging.info("Decoding started")
logging.info(params)
# lexicon = Lexicon(params.lang_dir)
# max_token_id = max(lexicon.tokens)
# num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-large-960h-lv60-self"
)
processor = Wav2Vec2Processor.from_pretrained(
"facebook/wav2vec2-large-960h-lv60-self"
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip
# it inside the for loop. That is, use
#
# if test_set == 'test-clean': continue
#
test_sets = ["test-clean", "test-other"]
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
results_dict = decode_dataset(
dl=test_dl,
model=model,
processor=processor,
)
save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()


@@ -0,0 +1,251 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import torch
from asr_datamodule import LibriSpeechAsrDataModule
from lhotse.features.io import NumpyHdf5Writer
from lhotse import CutSet
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
setup_logger,
)
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=34,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_500",
help="The lang dir",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="The experiment dir",
)
parser.add_argument(
"--mem-dir",
type=str,
default="conformer_ctc/exp/mem",
help="The experiment dir",
)
parser.add_argument(
"--num-utts",
type=int,
default=1000,
help="number of utts to extract memory embeddings",
)
parser.add_argument(
"--mem-layer",
type=int,
default=None,
help="which layer to extract memory embedding"
"See: https://github.com/glynpu/transformers/pull/1/files",
)
parser.add_argument(
"--pretrained_model",
type=Path,
default=None,
help="use a pretrained model, e.g. a modle downloaded from model zoo",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 80,
"nhead": 8,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
"vgg_frontend": False,
"use_feat_batchnorm": True,
"output_beam": 10,
"use_double_scores": True,
"env_info": get_env_info(),
}
)
return params
def compute_memory(
model: torch.nn.Module,
processor: Wav2Vec2Processor,
dl: torch.utils.data.DataLoader,
params: AttributeDict,
writer: NumpyHdf5Writer,
) -> CutSet:
"""Compute the framewise memory embeddings of a dataset.
Args:
model:
The neural network model.
processor:
The wav2vec2 processor that normalizes and pads raw audio.
dl:
Dataloader containing the dataset.
params:
Parameters for computing memory.
writer:
The writer used to store the memory embeddings on disk.
Returns:
Return a CutSet whose cuts carry references to the stored
memory embeddings.
"""
cuts = []
total_frames = 0
for batch_idx, batch in enumerate(dl):
inputs = processor(
batch["inputs"],
sampling_rate=16000,
return_tensors="pt",
padding="longest",
)
feature = inputs["input_values"].squeeze(0)
feature = feature.to(model.device)
B, T = feature.shape
supervisions = batch["supervisions"]
num_samples = supervisions["num_samples"]
mask = torch.arange(0, T).expand(B, T) < num_samples.reshape([-1, 1])
mask = mask.to(model.device)
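# model.wav2vec2 returns the hidden states of the layer selected via
# --mem-layer; these "memory" embeddings are the teacher targets.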
memory_embeddings = model.wav2vec2(feature, mask)[0] # [N, T, C]
encoder_memory = memory_embeddings.to("cpu").numpy()
cut_list = supervisions["cut"]
assert len(cut_list) == encoder_memory.shape[0]
assert all(c.start == 0 for c in supervisions["cut"])
for idx, cut in enumerate(cut_list):
num_frames = supervisions["num_samples"][idx] // 320
cut.encoder_memory = writer.store_array(
key=cut.id,
value=encoder_memory[idx][:num_frames],
)
total_frames += num_frames
cuts += cut_list
logging.info(f"Processed {len(cuts)} cuts")
if len(cuts) > params.num_utts:
break
return CutSet.from_cuts(cuts)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
assert args.mem_layer is not None
assert 0 < args.mem_layer < 24
assert args.return_cuts is True
assert args.concatenate_cuts is False
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/mem")
logging.info("Computing memory embedings- started")
logging.info(params)
logging.info("About to create model")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
params["device"] = device
# `output_layer_index` is assumed to be supported by the patched
# transformers from https://github.com/glynpu/transformers/pull/1/files
model = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-large-960h-lv60-self",
output_layer_index=params.mem_layer,
).to(device)
processor = Wav2Vec2Processor.from_pretrained(
"facebook/wav2vec2-large-960h-lv60-self"
)
model.eval()
librispeech = LibriSpeechAsrDataModule(args)
test_dl = librispeech.test_dataloaders() # a list
mem_dir = Path(params.mem_dir)
mem_dir.mkdir(exist_ok=True)
enabled_datasets = {
"test_clean": test_dl[0],
}
with NumpyHdf5Writer(
mem_dir / f"{args.mem_layer}layer-memory_embeddings"
) as writer:
for name, dl in enabled_datasets.items():
cut_set = compute_memory(
model=model,
processor=processor,
dl=dl,
params=params,
writer=writer,
)
cut_set.to_json(
mem_dir / f"{args.mem_layer}layer-memory_manifest.json.gz"
)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()
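# Overall flow sketched by this commit (run order; the scripts' file names
# are not shown in this diff view):
# 1) extract memory embeddings from a chosen wav2vec2 layer
#    -> {mem_layer}layer-memory_manifest.json.gz
# 2) train a quantizer on those embeddings
#    -> {mem_layer}layer-{quantizer_id}-bytes_per_frame_{n}-quantizer.pt
# 3) encode a LibriSpeech subset into codebook indices, which serve as
#    distillation targets from the wav2vec2 teacher model.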