train with full libri

This commit is contained in:
Guo Liyong 2021-12-23 18:39:13 +08:00
parent 8985440ce1
commit 3b42f0347f
7 changed files with 874 additions and 51 deletions

View File

@@ -0,0 +1,313 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
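
# A hypothetical example invocation (the script path, checkpoint path and
# layer index below are placeholders, not values taken from this commit):
#   ./conformer_ctc/extract_codebook_index.py \
#     --pretrained-model <path/to/pretrained.pt> \
#     --model-id icefall \
#     --mem-layer 6 \
#     --quantizer-id <id printed by quantizer training> \
#     --bytes-per-frame 4 \
#     --subset clean-100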
import argparse
import logging
from pathlib import Path
from quantization import Quantizer
import torch
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.features.io import NumpyHdf5Writer
from lhotse import CutSet
from icefall.checkpoint import load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=34,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_500",
help="The lang dir",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="The experiment dir",
)
parser.add_argument(
"--data-dir",
type=Path,
default="./data/",
help="The experiment dir",
)
parser.add_argument(
"--mem-dir",
type=Path,
default="conformer_ctc/exp/mem",
help="The experiment dir",
)
parser.add_argument(
"--quantizer-id",
type=str,
default=None,
help="quantizer_id",
)
parser.add_argument(
"--bytes-per-frame",
type=int,
default=4,
help="The number of bytes to use to quantize each memory embeddings",
)
parser.add_argument(
"--memory-embedding-dim",
type=int,
default=512,
help="dim of memory embeddings to train quantizer",
)
parser.add_argument(
"--pretrained-model",
type=Path,
default=None,
help="use a pretrained model, e.g. a modle downloaded from model zoo",
)
parser.add_argument(
"--model-id",
type=str,
default=None,
help="a short str to introduce which models the embeddings come from"
"e.g. icefall or wav2vec2",
)
parser.add_argument(
"--mem-layer",
type=int,
default=None,
help="which layer to extract memory embedding"
"Set this manully to avoid mistake.",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 80,
"nhead": 8,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
"vgg_frontend": False,
"use_feat_batchnorm": True,
"output_beam": 10,
"use_double_scores": True,
"env_info": get_env_info(),
}
)
return params
def compute_codeindices(
model: torch.nn.Module,
dl: torch.utils.data.DataLoader,
    quantizer: Quantizer,
    params: AttributeDict,
    writer: NumpyHdf5Writer,
) -> CutSet:
    """Compute the codebook indexes of a dataset.
    Args:
      model:
        The neural network model.
      dl:
        Dataloader containing the dataset.
      quantizer:
        The trained quantizer used to encode the memory embeddings.
      params:
        Parameters for computing codebook indexes.
      writer:
        The writer used to store the codebook indexes.
    Returns:
      Return a CutSet, where each cut has its codebook indexes attached.
    """
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
num_cuts = 0
device = params.device
cuts = []
total_frames = 0
for batch_idx, batch in enumerate(dl):
feature = batch["inputs"]
# at entry, feature is [N, T, C]
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
_, encoder_memory, memory_mask = model(feature, supervisions)
codebook_indices = quantizer.encode(encoder_memory, as_bytes=True)
# [T, N, C] --> [N, T, C]
codebook_indices = codebook_indices.transpose(0, 1).to("cpu").numpy()
# for idx, cut in enumerate(cut_ids):
cut_list = supervisions["cut"]
assert len(cut_list) == codebook_indices.shape[0]
num_cuts += len(cut_list)
assert all(supervisions["start_frame"] == 0)
for idx, cut in enumerate(cut_list):
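            # Map the number of input feature frames to the number of encoder
            # output frames. Presumably this mirrors two stride-2, kernel-3
            # convolutions in the subsampling frontend (subsampling_factor=4):
            # T -> (T - 3) // 2 + 1, applied twice.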
num_frames = (
((supervisions["num_frames"][idx] - 3) // 2 + 1) - 3
) // 2 + 1
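            # frame_shift=0.04 s below is presumably the 0.01 s feature frame
            # shift multiplied by the subsampling factor of 4.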
cut.codebook_indices = writer.store_array(
key=cut.id,
value=codebook_indices[idx][:num_frames],
frame_shift=0.04,
temporal_dim=0,
start=0,
)
total_frames += num_frames
cuts += cut_list
logging.info(
f"processed {total_frames} frames and {num_cuts} cuts; {batch_idx} of {num_batches}" # noqa: E501
)
return CutSet.from_cuts(cuts)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
assert args.return_cuts is True
assert args.concatenate_cuts is False
assert args.quantizer_id is not None
assert args.model_id is not None
assert args.mem_layer is not None
assert args.pretrained_model is not None
assert args.subset in ["clean-100", "clean-360", "other-500"]
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/mem")
logging.info("Computing memory embedings- started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
logging.info("About to create model")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
quantizer_fn = (
params.mem_dir
/ f"{params.mem_layer}layer-{params.quantizer_id}-bytes_per_frame_{params.bytes_per_frame}-quantizer.pt" # noqa: E501
)
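    # With codebook_size=256, each codebook index fits in one byte, so the
    # encode(..., as_bytes=True) call in compute_codeindices stores each
    # frame as num_codebooks == bytes_per_frame bytes.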
quantizer = Quantizer(
dim=params.memory_embedding_dim,
num_codebooks=args.bytes_per_frame,
codebook_size=256,
)
quantizer.load_state_dict(torch.load(quantizer_fn))
quantizer = quantizer.to("cuda")
load_checkpoint(f"{params.pretrained_model}", model)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
params["device"] = device
model.to(device)
model.eval()
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
cdidx_dir = (
Path(params.data_dir)
/ f"{args.model_id}-{args.mem_layer}layer-{args.quantizer_id}-bytes_per_frame-{args.bytes_per_frame}" # noqa: E501
)
cdidx_dir.mkdir(exist_ok=True)
with NumpyHdf5Writer(
cdidx_dir
/ f"{args.model_id}-{args.mem_layer}layer-cdidx_train-{args.subset}"
) as writer:
cut_set = compute_codeindices(
model=model,
dl=train_dl,
quantizer=quantizer,
params=params,
writer=writer,
)
cut_set.to_json(cdidx_dir / f"cuts_train-{args.subset}.json.gz")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@@ -23,6 +23,9 @@ from typing import Optional, Tuple
 import torch
 from torch import Tensor, nn
 from transformer import Supervisions, Transformer, encoder_padding_mask
+from prediction import JointCodebookPredictor
+from ckpnt_prediction import JointCodebookLoss
+from powerful_prediction import Powerful_JointCodebookLoss


 class CodeIndicesNet(nn.Module):
@@ -51,18 +54,9 @@ class CodeIndicesNet(nn.Module):
         self.num_codebooks = num_codebooks
         self.quantizer_dim = quantizer_dim

-    def forward(self, memory):
-        """
-        Args:
-          memory:
-            memory embeddings, with shape [T, N, C]
-        output:
-            shape [N, T, num_codebooks*quantizer_dim]
-        """
-        x = self.linear1(memory)
-        return x
-
-    def loss(self, memory: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, memory: torch.Tensor, target: torch.Tensor
+    ) -> torch.Tensor:
         """
         Args:
           memory:
@@ -75,12 +69,14 @@ class CodeIndicesNet(nn.Module):
             actually it's the sum of num_codebooks CE losses
         """
-        memory = memory.transpose(0, 1)  # T, N, C --> N, T, C
-        x = self.forward(memory)
+        x = self.linear1(memory)
         x = x.reshape(-1, self.quantizer_dim)
         target = target.reshape(-1)
+        assert (
+            x.shape[0] == target.shape[0]
+        ), f"x.shape: {x.shape} while target.shape: {target.shape}"
         ret = self.ce(x, target)
-        return ret
+        return -ret, None


 class Conformer(Transformer):
@@ -115,6 +111,9 @@ class Conformer(Transformer):
         normalize_before: bool = True,
         vgg_frontend: bool = False,
         use_feat_batchnorm: bool = False,
+        use_codebook_loss: bool = False,
+        num_codebooks: int = 4,
+        predictor: str = "predictor",  # "simple_linear", "predictor", "ckpnt_predictor", "powerful"
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -150,7 +149,27 @@ class Conformer(Transformer):
             # and throws an error without this change.
             self.after_norm = identity

-        self.cdidxnet = CodeIndicesNet()
+        if use_codebook_loss:
+            assert predictor in [
+                "powerful",
+                "predictor",
+                "ckpnt_predictor",
+                "simple_linear",
+            ]
+            if predictor == "predictor":
+                self.cdidxnet = JointCodebookPredictor(
+                    predictor_dim=512, num_codebooks=num_codebooks
+                )
+            elif predictor == "ckpnt_predictor":
+                self.cdidxnet = JointCodebookLoss(
+                    predictor_channels=512, num_codebooks=num_codebooks
+                )
+            elif predictor == "simple_linear":
+                self.cdidxnet = CodeIndicesNet(num_codebooks=num_codebooks)
+            elif predictor == "powerful":
+                self.cdidxnet = Powerful_JointCodebookLoss(
+                    predictor_channels=512, num_codebooks=num_codebooks
+                )

     def run_encoder(
         self, x: Tensor, supervisions: Optional[Supervisions] = None

View File

@@ -499,10 +499,10 @@ def save_results(
     enable_log = True
     test_set_wers = dict()
     for key, results in results_dict.items():
+        result_file_prefix = f"epoch-{params.epoch}-avg-{params.avg}-"
         recog_path = (
             params.exp_dir
-            / f"epoch-{params.epoch}-avg-{params.avg}- \
-              recogs-{test_set_name}-{key}.txt"
+            / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
         )
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
@@ -512,8 +512,7 @@ def save_results(
             # ref/hyp pairs.
             errs_filename = (
                 params.exp_dir
-                / f"epoch-{params.epoch}-avg-{params.avg}- \
-                  errs-{test_set_name}-{key}.txt"
+                / f"{result_file_prefix}errs-{test_set_name}-{key}.txt"
             )
             with open(errs_filename, "w") as f:
                 wer = write_error_stats(
@@ -528,9 +527,7 @@ def save_results(
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.exp_dir
-        / f"epoch-{params.epoch}-avg-{params.avg}- \
-          wer-summary-{test_set_name}.txt"
+        params.exp_dir / f"{result_file_prefix}wer-summary-{test_set_name}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)

View File

@@ -0,0 +1,250 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
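
# A hypothetical example invocation (the script path, checkpoint path and
# layer index below are placeholders, not values taken from this commit):
#   ./conformer_ctc/compute_memory.py \
#     --pretrained-model <path/to/pretrained.pt> \
#     --mem-layer 6 \
#     --num-utts 1000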
import argparse
import logging
from pathlib import Path
import torch
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.features.io import NumpyHdf5Writer
from lhotse import CutSet
from icefall.checkpoint import load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=34,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_500",
help="The lang dir",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="The experiment dir",
)
parser.add_argument(
"--mem-dir",
type=str,
default="conformer_ctc/exp/mem",
help="The experiment dir",
)
parser.add_argument(
"--num-utts",
type=int,
default=1000,
help="number of utts to extract memory embeddings",
)
parser.add_argument(
"--mem-layer",
type=int,
default=None,
help="which layer to extract memory embedding",
)
parser.add_argument(
"--pretrained-model",
type=Path,
default=None,
help="use a pretrained model, e.g. a modle downloaded from model zoo",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 80,
"nhead": 8,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
"vgg_frontend": False,
"use_feat_batchnorm": True,
"output_beam": 10,
"use_double_scores": True,
"env_info": get_env_info(),
}
)
return params
def compute_memory(
model: torch.nn.Module,
dl: torch.utils.data.DataLoader,
params: AttributeDict,
    writer: NumpyHdf5Writer,
) -> CutSet:
    """Compute the memory embeddings of a dataset.
    Args:
      model:
        The neural network model.
      dl:
        Dataloader containing the dataset.
      params:
        Parameters for computing memory embeddings.
      writer:
        The writer used to store the memory embeddings.
    Returns:
      Return a CutSet, where each cut has its memory embeddings attached.
    """
num_cuts = 0
device = params.device
cuts = []
total_frames = 0
for batch_idx, batch in enumerate(dl):
feature = batch["inputs"]
# at entry, feature is [N, T, C]
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
_, encoder_memory, memory_mask = model(feature, supervisions)
# [T, N, C] --> [N, T, C]
encoder_memory = encoder_memory.transpose(0, 1).to("cpu").numpy()
cut_list = supervisions["cut"]
assert len(cut_list) == encoder_memory.shape[0]
assert all(supervisions["start_frame"] == 0)
for idx, cut in enumerate(cut_list):
num_frames = supervisions["num_frames"][idx]
cut.encoder_memory = writer.store_array(
key=cut.id,
value=encoder_memory[idx][:num_frames],
)
total_frames += num_frames
cuts += cut_list
num_cuts += len(cut_list)
logging.info(f"processed {total_frames} frames and {num_cuts} cuts.")
if len(cuts) > params.num_utts:
break
return CutSet.from_cuts(cuts)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
assert args.return_cuts is True
assert args.concatenate_cuts is False
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/mem")
logging.info("Computing memory embedings- started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
logging.info("About to create model")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
assert params.pretrained_model is not None
load_checkpoint(f"{params.pretrained_model}", model)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
params["device"] = device
model.to(device)
model.eval()
librispeech = LibriSpeechAsrDataModule(args)
test_dl = librispeech.test_dataloaders() # a list
mem_dir = Path(params.mem_dir)
mem_dir.mkdir(exist_ok=True)
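    # Only test-clean is used here; extend this dict with more entries
    # (test_dl is a list with one dataloader per test set) to dump memory
    # embeddings for other datasets as well.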
enabled_datasets = {
"test_clean": test_dl[0],
}
mem_storage = mem_dir / f"{args.mem_layer}layer-memory_embeddings"
mem_manifest = mem_dir / f"{args.mem_layer}layer-memory_manifest.json"
with NumpyHdf5Writer(mem_storage) as writer:
for name, dl in enabled_datasets.items():
cut_set = compute_memory(
model=model,
dl=dl,
params=params,
writer=writer,
)
cut_set.to_json(mem_manifest)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,143 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
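
# A hypothetical example invocation (the script path and layer index below
# are placeholders, not values taken from this commit):
#   ./conformer_ctc/quantizer_train.py \
#     --memory-embedding-dim 512 \
#     --bytes-per-frame 4 \
#     --output-layer-index 6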
import argparse
import logging
import os
from pathlib import Path
from lhotse import load_manifest
from lhotse.dataset import (
BucketingSampler,
K2SpeechRecognitionDataset,
)
from torch.utils.data import DataLoader
from icefall.utils import setup_logger
import torch
import quantization
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--bytes-per-frame",
type=int,
default=4,
help="The number of bytes to use to quantize each memory embeddings"
"Usually, it's equal to number codebooks",
)
parser.add_argument(
"--memory-embedding-dim",
type=int,
default=1024,
help="dim of memory embeddings to train quantizer",
)
parser.add_argument(
"--mem-dir",
type=Path,
default="conformer_ctc/exp/mem",
help="The experiment dir",
)
parser.add_argument(
"--output-layer-index",
type=int,
default=None,
help="which layer to extract memory embedding"
"Specify this manully every time incase of mistakes",
)
return parser
def initialize_memory_dataloader(
mem_dir: Path = None, output_layer_index: int = None
):
assert mem_dir is not None
assert output_layer_index is not None
mem_manifest_file = (
mem_dir / f"{output_layer_index}layer-memory_manifest.json"
)
assert os.path.isfile(
mem_manifest_file
), f"{mem_manifest_file} does not exist."
cuts = load_manifest(mem_manifest_file)
dataset = K2SpeechRecognitionDataset(return_cuts=True)
max_duration = 1
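    # Note: max_duration is measured in seconds of audio per batch; the tiny
    # value presumably keeps each quantizer training step small.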
sampler = BucketingSampler(
cuts,
max_duration=max_duration,
shuffle=False,
)
dl = DataLoader(dataset, batch_size=None, sampler=sampler, num_workers=4)
return dl
def main():
parser = get_parser()
args = parser.parse_args()
assert args.output_layer_index is not None
setup_logger(f"{args.mem_dir}/log/quantizer_train")
trainer = quantization.QuantizerTrainer(
dim=args.memory_embedding_dim,
bytes_per_frame=args.bytes_per_frame,
device=torch.device("cuda"),
)
dl = initialize_memory_dataloader(args.mem_dir, args.output_layer_index)
num_cuts = 0
done_flag = False
epoch = 0
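    # Keep feeding memory embeddings to the trainer until it reports done();
    # the dataloader is re-created whenever an epoch is exhausted.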
while not trainer.done():
for batch in dl:
cuts = batch["supervisions"]["cut"]
embeddings = torch.cat(
[
torch.from_numpy(c.load_custom("encoder_memory"))
for c in cuts
]
)
embeddings = embeddings.to("cuda")
num_cuts += len(cuts)
trainer.step(embeddings)
if trainer.done():
done_flag = True
break
if done_flag:
break
else:
epoch += 1
dl = initialize_memory_dataloader(
args.mem_dir, args.output_layer_index
)
quantizer = trainer.get_quantizer()
quantizer_fn = (
f"{args.output_layer_index}layer-"
+ quantizer.get_id()
+ f"-bytes_per_frame_{args.bytes_per_frame}-quantizer.pt"
)
quantizer_fn = args.mem_dir / quantizer_fn
torch.save(quantizer.state_dict(), quantizer_fn)
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
main()

View File

@@ -30,6 +30,7 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
+from lhotse.cut import MonoCut
 from lhotse.utils import fix_random_seed
 from lhotse.dataset.collation import collate_custom_field
 from torch import Tensor
@@ -65,6 +66,13 @@ def get_parser():
         help="Number of GPUs for DDP training.",
     )

+    parser.add_argument(
+        "--bytes-per-frame",
+        type=int,
+        default=4,
+        help="Number of codebooks used to quantize each frame",
+    )
+
     parser.add_argument(
         "--master-port",
         type=int,
@@ -79,6 +87,13 @@ def get_parser():
         help="Should various information be logged in tensorboard.",
     )

+    parser.add_argument(
+        "--predictor",
+        type=str,
+        default=None,
+        help="Predictor type: simple_linear, predictor, "
+        "ckpnt_predictor or powerful",
+    )
+
     parser.add_argument(
         "--num-epochs",
         type=int,
@@ -103,6 +118,7 @@ def get_parser():
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
+        Note: no trailing "/".
         """,
     )
@@ -128,7 +144,7 @@ def get_parser():
     parser.add_argument(
         "--codebook-weight",
         type=float,
-        default=0.1,
+        default=0.3,
         help="""The weight of the codebook loss.
         Note: Currently rate of ctc_loss + rate of att_loss = 1.0
         codebook_weight is independent of the previous two.
@@ -142,6 +158,14 @@ def get_parser():
         help="The lr_factor for Noam optimizer",
     )

+    parser.add_argument(
+        "--model-id",
+        type=str,
+        default=None,
+        help="A short string identifying which model the embeddings "
+        "come from, e.g. icefall or wav2vec2",
+    )
+
     return parser
@@ -406,27 +430,42 @@ def compute_loss(
         )
         loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
-        if params.codebook_weight != 0.0:
+        if params.codebook_weight > 0.0 and is_training:
             cuts = batch["supervisions"]["cut"]
             # -100 is identical to ignore_value in CE loss computation.
+            cuts_pre_mixed = [
+                c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
+            ]
             codebook_indices, codebook_indices_lens = collate_custom_field(
-                cuts, "codebook_indices", pad_value=-100
+                cuts_pre_mixed, "codebook_indices", pad_value=-100
             )
             assert (
                 codebook_indices.shape[0] == encoder_memory.shape[1]
             )  # N: batch_size
-            assert (
-                codebook_indices.shape[1] == encoder_memory.shape[0]
-            )  # T: num frames
+
+            if "wav2vec" == params.model_id:
+                # The frame rate of wav2vec codebook_indices is 50,
+                # while for the conformer it is 25.
+                t_expected = encoder_memory.shape[0] * 2
+                assert codebook_indices.shape[1] >= t_expected
+                codebook_indices = codebook_indices[:, 0:t_expected:2, :]
+
+            encoder_memory = encoder_memory.transpose(0, 1)  # T, N, C --> N, T, C
             codebook_indices = codebook_indices.to(encoder_memory.device).long()
-            codebook_loss = mmodel.cdidxnet.loss(
-                encoder_memory, target=codebook_indices
-            )
+            if (
+                params.predictor == "ckpnt_predictor"
+                or params.predictor == "powerful"
+            ):
+                codebook_loss = mmodel.cdidxnet(encoder_memory, codebook_indices)
+            else:
+                total_logprob, _ = mmodel.cdidxnet(encoder_memory, codebook_indices)
+                codebook_loss = -total_logprob
             loss += params.codebook_weight * codebook_loss
-    else:
+
+    if params.codebook_weight == 0.0 and params.att_rate == 0.0:
         loss = ctc_loss
         att_loss = torch.tensor([0])
@@ -438,7 +477,7 @@ def compute_loss(
     if params.att_rate != 0.0:
         info["att_loss"] = att_loss.detach().cpu().item()

-    if params.codebook_weight != 0.0:
+    if params.codebook_weight > 0.0 and is_training:
         info["codebook_loss"] = codebook_loss.detach().cpu().item()

     info["loss"] = loss.detach().cpu().item()
@@ -633,6 +672,9 @@ def run(rank, world_size, args):
         num_decoder_layers=params.num_decoder_layers,
         vgg_frontend=False,
         use_feat_batchnorm=params.use_feat_batchnorm,
+        use_codebook_loss=params.codebook_weight > 0.0,
+        num_codebooks=params.bytes_per_frame,
+        predictor=params.predictor,
     )

     checkpoints = load_checkpoint_if_available(params=params, model=model)
@@ -747,7 +789,12 @@ def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
-    args.exp_dir = Path(args.exp_dir)
+    if 0.0 != args.codebook_weight:
+        assert -1 == args.time_warp_factor
+        assert not args.exp_dir.endswith("/")
+        args.exp_dir = Path(
+            f"{args.exp_dir}-time_warp_factor{args.time_warp_factor}-bytes_per_frame{args.bytes_per_frame}-cdweight{args.codebook_weight}-predictor{args.predictor}-maxduration{args.max_duration}"  # noqa: E501
+        )
     args.lang_dir = Path(args.lang_dir)
     world_size = args.world_size

View File

@@ -31,7 +31,7 @@ from lhotse.dataset import (
     SingleCutSampler,
     SpecAugment,
 )
-from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.dataset.input_strategies import AudioSamples, OnTheFlyFeatures
 from torch.utils.data import DataLoader

 from icefall.dataset.datamodule import DataModule
@@ -73,6 +73,21 @@ class LibriSpeechAsrDataModule(DataModule):
             help="When enabled, use 960h LibriSpeech. "
             "Otherwise, use 100h subset.",
         )
+        group.add_argument(
+            "--subset",
+            type=str,
+            default=None,
+            help="Which subset to extract codebook indexes from: "
+            "clean-100, clean-360 or other-500.",
+        )
+        group.add_argument(
+            "--enable-augmentation",
+            type=str2bool,
+            default=True,
+            help="Set to False to disable all augmentation. "
+            "Used when extracting codebook_indexes.",
+        )
         group.add_argument(
             "--feature-dir",
             type=Path,
@@ -100,6 +115,13 @@ class LibriSpeechAsrDataModule(DataModule):
             help="The number of buckets for the BucketingSampler "
             "(you might want to increase it for larger datasets).",
         )
+        group.add_argument(
+            "--time-warp-factor",
+            type=int,
+            default=80,
+            help="Set to None or a value less than 1 to disable. "
+            "See lhotse.dataset.signal_transforms for details.",
+        )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
@@ -154,7 +176,16 @@ class LibriSpeechAsrDataModule(DataModule):
             "collect the batches.",
         )
+        group.add_argument(
+            "--input-strategy",
+            type=str,
+            default="PrecomputedFeatures",
+            help="The input strategy for the dataset: "
+            "AudioSamples or PrecomputedFeatures.",
+        )

     def train_dataloaders(self) -> DataLoader:
+        logging.info(f"enable-augmentation: {self.args.enable_augmentation}")
         logging.info("About to get train cuts")
         cuts_train = self.train_cuts()
@@ -181,6 +212,7 @@ class LibriSpeechAsrDataModule(DataModule):
         input_transforms = [
             SpecAugment(
+                time_warp_factor=self.args.time_warp_factor,
                 num_frame_masks=2,
                 features_mask_size=27,
                 num_feature_masks=2,
@@ -189,12 +221,21 @@ class LibriSpeechAsrDataModule(DataModule):
         ]

         train = K2SpeechRecognitionDataset(
-            cut_transforms=transforms,
-            input_transforms=input_transforms,
+            input_strategy=AudioSamples()
+            if self.args.input_strategy == "AudioSamples"
+            else PrecomputedFeatures(),
+            cut_transforms=transforms
+            if self.args.enable_augmentation
+            else None,
+            input_transforms=input_transforms
+            if self.args.enable_augmentation
+            else None,
             return_cuts=self.args.return_cuts,
         )

         if self.args.on_the_fly_feats:
+            assert self.args.enable_augmentation
+            # enable_augmentation == False is only tested with precomputed features.  # noqa
             # NOTE: the PerturbSpeed transform should be added only if we
             # remove it from data prep stage.
             # Add on-the-fly speed perturbation; since originally it would
@@ -222,7 +263,7 @@ class LibriSpeechAsrDataModule(DataModule):
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
                 bucket_method="equal_duration",
-                drop_last=True,
+                drop_last=self.args.enable_augmentation,
             )
         else:
             logging.info("Using SingleCutSampler.")
@@ -294,14 +335,20 @@ class LibriSpeechAsrDataModule(DataModule):
         for cuts_test in cuts:
             logging.debug("About to create test dataset")
-            test = K2SpeechRecognitionDataset(
-                input_strategy=OnTheFlyFeatures(
-                    Fbank(FbankConfig(num_mel_bins=80))
-                )
-                if self.args.on_the_fly_feats
-                else PrecomputedFeatures(),
-                return_cuts=self.args.return_cuts,
-            )
+            if self.args.input_strategy == "AudioSamples":
+                test = K2SpeechRecognitionDataset(
+                    input_strategy=AudioSamples(),
+                    return_cuts=self.args.return_cuts,
+                )
+            else:
+                test = K2SpeechRecognitionDataset(
+                    input_strategy=OnTheFlyFeatures(
+                        Fbank(FbankConfig(num_mel_bins=80))
+                    )
+                    if self.args.on_the_fly_feats
+                    else PrecomputedFeatures(),
+                    return_cuts=self.args.return_cuts,
+                )
             sampler = BucketingSampler(
                 cuts_test, max_duration=self.args.max_duration, shuffle=False
             )
@@ -322,19 +369,26 @@ class LibriSpeechAsrDataModule(DataModule):
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        cuts_train = load_manifest(
-            self.args.feature_dir / "cuts_train-clean-100.json.gz"
-        )
         if self.args.full_libri:
+            assert self.args.subset is None
+            cuts_train = load_manifest(
+                self.args.feature_dir / "cuts_train-clean-100.json"
+            )
             cuts_train = (
                 cuts_train
                 + load_manifest(
-                    self.args.feature_dir / "cuts_train-clean-360.json.gz"
+                    self.args.feature_dir / "cuts_train-clean-360.json"
                 )
                 + load_manifest(
-                    self.args.feature_dir / "cuts_train-other-500.json.gz"
+                    self.args.feature_dir / "cuts_train-other-500.json"
                 )
             )
+        if self.args.subset is not None:
+            assert not self.args.full_libri
+            assert self.args.subset in ["clean-100", "clean-360", "other-500"]
+            cuts_train = load_manifest(
+                self.args.feature_dir / f"cuts_train-{self.args.subset}.json.gz"
+            )
         return cuts_train

     @lru_cache()