apply black and isort

Dongji Gao 2023-09-24 11:44:39 -04:00
parent 1ea86de1da
commit 8178a0effc
13 changed files with 151 additions and 60 deletions
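The changes below are purely mechanical reformatting. For reference, a minimal sketch of producing the same kind of changes from Python rather than the command line (assumes the black and isort packages are installed; the sample source string is illustrative, not from this repository):

import black
import isort

src = "from z import c\nfrom a import b\ndef f(x, y,\n      z):\n    return x\n"

# isort.code() normalizes the import block; black.format_str() reflows the rest.
src = isort.code(src)
print(black.format_str(src, mode=black.Mode(line_length=88)))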

View File

@@ -183,11 +183,13 @@ class LibriSpeechAsrDataModule:
"--train-manifest",
type=str,
default="librispeech_cuts_train-clean-100.jsonl.gz",
help="Train manifest file."
help="Train manifest file.",
)
def train_dataloaders(
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None,
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
@@ -268,11 +270,14 @@ class LibriSpeechAsrDataModule:
logging.info("About to create dev dataset")
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms, return_cuts=self.args.return_cuts,
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid, max_duration=self.args.max_duration, shuffle=False,
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
@@ -293,11 +298,16 @@ class LibriSpeechAsrDataModule:
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False,
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers,
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@@ -311,9 +321,7 @@ class LibriSpeechAsrDataModule:
@lru_cache()
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
return load_manifest_lazy(
self.args.manifest_dir / self.args.train_manifest
)
return load_manifest_lazy(self.args.manifest_dir / self.args.train_manifest)
@lru_cache()
def train_all_shuf_cuts(self) -> CutSet:
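Most of the churn in this file comes from black's "magic trailing comma": a call that already ends with a trailing comma is exploded to one argument per line, while a multi-line call without one that fits within the line limit is collapsed onto a single line. A small illustration using plain black library calls on made-up snippets (not code from this file):

import black

print(black.format_str("f(a, b, c,)\n", mode=black.Mode()))
# f(
#     a,
#     b,
#     c,
# )
print(black.format_str("f(\n    a, b, c\n)\n", mode=black.Mode()))
# f(a, b, c)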

View File

@@ -92,7 +92,7 @@ class Conformer(Transformer):
if self.subsampling_factor == 4:
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
elif self.subsampling_factor == 2:
self.encoder_embed = Conv2dSubsampling2(num_features, d_model)
self.encoder_embed = Conv2dSubsampling2(num_features, d_model)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)

View File

@@ -32,19 +32,16 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import (
get_lattice,
one_best_decoding,
)
from icefall.decode import get_lattice, one_best_decoding
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
from icefall.utils import (
AttributeDict,
get_texts,
@@ -62,7 +59,10 @@ def get_parser():
)
parser.add_argument(
"--otc-token", type=str, default="<star>", help="OTC token",
"--otc-token",
type=str,
default="<star>",
help="OTC token",
)
parser.add_argument(
@@ -137,11 +137,17 @@ def get_parser():
)
parser.add_argument(
"--exp-dir", type=str, default="conformer_ctc2/exp", help="The experiment dir",
"--exp-dir",
type=str,
default="conformer_ctc2/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir", type=str, default="data/lang_bpe_200", help="The lang dir",
"--lang-dir",
type=str,
default="data/lang_bpe_200",
help="The lang dir",
)
parser.add_argument(
@@ -345,7 +351,11 @@ def decode_one_batch(
return {key: hyps}
if params.method == "ctc-greedy-search":
hyps, _ = ctc_greedy_search(nnet_output, memory, memory_key_padding_mask,)
hyps, _ = ctc_greedy_search(
nnet_output,
memory,
memory_key_padding_mask,
)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(hyps)
@@ -557,7 +567,11 @@ def main():
if params.method == "ctc-decoding" or params.method == "ctc-greedy-search":
HLG = None
H = k2.ctc_topo(max_token=max_token_id, modified=False, device=device,)
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
device=device,
)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
else:
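For context on the ctc-greedy-search branch above: greedy CTC decoding takes the framewise argmax, collapses repeated tokens, and drops blanks. A stand-alone sketch of that core step (the recipe's ctc_greedy_search additionally uses the decoder memory and padding mask, so this shows only the idea, not the function being called):

from typing import List

import torch


def ctc_greedy_ids(log_probs: torch.Tensor, blank_id: int = 0) -> List[int]:
    # log_probs: (T, vocab_size) per-frame log-probabilities for one utterance.
    best = log_probs.argmax(dim=-1).tolist()
    hyp: List[int] = []
    prev = blank_id
    for tok in best:
        if tok != blank_id and tok != prev:  # collapse repeats, skip blanks
            hyp.append(tok)
        prev = tok
    return hyp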

View File

@@ -120,6 +120,7 @@ class Conv2dSubsampling(torch.nn.Module):
x = self.out_balancer(x)
return x
class Conv2dSubsampling2(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/2 length).

View File

@@ -66,24 +66,24 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from icefall import diagnostics
from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx,
update_averaged_model,
)
from icefall.decode import one_best_decoding
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions_otc,
get_texts,
setup_logger,
str2bool,
get_texts,
)
from icefall.decode import one_best_decoding
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -94,7 +94,10 @@ def get_parser():
)
parser.add_argument(
"--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
@@ -112,7 +115,10 @@ def get_parser():
)
parser.add_argument(
"--num-epochs", type=int, default=20, help="Number of epochs to train.",
"--num-epochs",
type=int,
default=20,
help="Number of epochs to train.",
)
parser.add_argument(
@@ -255,7 +261,18 @@ def get_parser():
)
parser.add_argument(
"--otc-token", type=str, default="_<star>", help="OTC token",
"--otc-token",
type=str,
default="_<star>",
help="OTC token",
)
parser.add_argument(
"--otc-granularity",
type=str,
choices=["word", "subword"],
default="word",
help="OTC granularity",
)
parser.add_argument(
@@ -374,7 +391,7 @@ def get_params() -> AttributeDict:
"log_interval": 1,
"reset_interval": 200,
"valid_interval": 800, # For the 100h subset, use 800
"alignment_interval": 100,
"alignment_interval": 25,
# parameters for conformer
"feature_dim": 768,
"subsampling_factor": 2,
@@ -585,9 +602,14 @@ def compute_loss(
allow_self_loop_arc=params.allow_self_loop_arc,
bypass_weight=bypass_weight,
self_loop_weight=self_loop_weight,
otc_granularity=params.otc_granularity,
)
dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments, allow_truncate=3,)
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
supervision_segments,
allow_truncate=3,
)
otc_loss = k2.ctc_loss(
decoding_graph=decoding_graph,
@@ -627,18 +649,22 @@ def compute_loss(
utt_id = utt_ids[index]
lattice = k2.intersect_dense(
decoding_graph, dense_fsa_vec, params.beam_size,
decoding_graph,
dense_fsa_vec,
params.beam_size,
)
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores,
lattice=lattice,
use_double_scores=params.use_double_scores,
)
hyp_ids = get_texts(best_path)[index]
hyp_text_list = [graph_compiler.token_table[i] for i in hyp_ids]
hyp_text = " ".join(hyp_text_list)
hyp_text = "".join(hyp_text_list).replace("", " ")
logging.info(f"[utterance id]: {utt_id}")
logging.info(f"[verbatim text]: {verbatim_text}")
logging.info(f"[best alignment]: {hyp_text}")
logging.info(bypass_weight)
return loss, info
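The change to the alignment logging above joins the BPE pieces with "".join(...) and then maps the word-boundary marker "▁" back to a space; joining the pieces with spaces would split words apart. A quick illustration with made-up pieces:

pieces = ["▁HE", "LLO", "▁WOR", "LD"]  # hypothetical BPE pieces
print(" ".join(pieces))                             # ▁HE LLO ▁WOR LD  (broken word boundaries)
print("".join(pieces).replace("▁", " ").strip())    # HELLO WORLD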
@@ -770,7 +796,9 @@ def train_one_epoch(
and params.batch_idx_train % params.average_period == 0
):
update_averaged_model(
params=params, model_cur=model, model_avg=model_avg,
params=params,
model_cur=model,
model_avg=model_avg,
)
if (
@@ -790,7 +818,9 @@ def train_one_epoch(
rank=rank,
)
remove_checkpoints(
out_dir=params.exp_dir, topk=params.keep_last_k, rank=rank,
out_dir=params.exp_dir,
topk=params.keep_last_k,
rank=rank,
)
if batch_idx % params.log_interval == 0:

View File

@@ -909,7 +909,9 @@ class Noam(object):
def encoder_padding_mask(
max_len: int, subsampling_factor: Optional[int] = 4, supervisions: Optional[Supervisions] = None
max_len: int,
subsampling_factor: Optional[int] = 4,
supervisions: Optional[Supervisions] = None,
) -> Optional[torch.Tensor]:
"""Make mask tensor containing indexes of padded part.

View File

@@ -29,7 +29,7 @@ import os
from pathlib import Path
import torch
from lhotse import CutSet, S3PRLSSL, S3PRLSSLConfig, NumpyFilesWriter
from lhotse import S3PRLSSL, CutSet, NumpyFilesWriter, S3PRLSSLConfig
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
@@ -71,9 +71,7 @@ def compute_ssl_librispeech():
dataset_parts,
)
extractor = S3PRLSSL(
S3PRLSSLConfig(ssl_model="wav2vec2", device="cuda")
)
extractor = S3PRLSSL(S3PRLSSLConfig(ssl_model="wav2vec2", device="cuda"))
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
@@ -95,9 +93,7 @@ def compute_ssl_librispeech():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)

View File

@@ -18,7 +18,9 @@ def get_args():
)
parser.add_argument(
"--otc-token", type=str, help="OTC token to be added to words.txt",
"--otc-token",
type=str,
help="OTC token to be added to words.txt",
)
return parser.parse_args()
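Per its help string, this script appends the OTC token to the existing words.txt symbol table. A hedged sketch of that step (the path and token are illustrative; the real script takes them from the arguments above):

from pathlib import Path

words_file = Path("data/lang/words.txt")   # illustrative path
otc_token = "<star>"                       # illustrative token
lines = words_file.read_text().splitlines()
next_id = max(int(line.split()[1]) for line in lines) + 1
words_file.write_text("\n".join(lines + [f"{otc_token} {next_id}"]) + "\n")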

View File

@@ -7,10 +7,11 @@ import random
from pathlib import Path
from typing import List
from icefall.utils import str2bool
from lhotse import CutSet, load_manifest
from lhotse.cut.base import Cut
from icefall.utils import str2bool
def get_args():
parser = argparse.ArgumentParser(
@@ -23,23 +24,36 @@ def get_args():
)
parser.add_argument(
"--words-file", type=str, help="words.txt file",
"--words-file",
type=str,
help="words.txt file",
)
parser.add_argument(
"--otc-token", type=str, help="OTC token in words.txt",
"--otc-token",
type=str,
help="OTC token in words.txt",
)
parser.add_argument(
"--sub-error-rate", type=float, default=0.0, help="Substitution error rate",
"--sub-error-rate",
type=float,
default=0.0,
help="Substitution error rate",
)
parser.add_argument(
"--ins-error-rate", type=float, default=0.0, help="Insertion error rate",
"--ins-error-rate",
type=float,
default=0.0,
help="Insertion error rate",
)
parser.add_argument(
"--del-error-rate", type=float, default=0.0, help="Deletion error rate",
"--del-error-rate",
type=float,
default=0.0,
help="Deletion error rate",
)
parser.add_argument(

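This script injects synthetic substitution, insertion, and deletion errors into the supervisions at the rates given above. A hedged sketch of the idea on a plain word list (the real script operates on lhotse cuts and draws replacements from words.txt; names and the sampling scheme here are illustrative):

import random
from typing import List


def corrupt(words: List[str], vocab: List[str],
            sub_rate: float, ins_rate: float, del_rate: float) -> List[str]:
    out: List[str] = []
    for word in words:
        r = random.random()
        if r < del_rate:
            continue                          # deletion: drop the word
        elif r < del_rate + sub_rate:
            out.append(random.choice(vocab))  # substitution: random vocab word
        else:
            out.append(word)
        if random.random() < ins_rate:
            out.append(random.choice(vocab))  # insertion: extra vocab word
    return out


print(corrupt("the cat sat".split(), ["dog", "mat", "ran"], 0.1, 0.1, 0.1))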
View File

@@ -324,7 +324,9 @@ def lexicon_to_fst(
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs, disambig_token=disambig_token, disambig_word=disambig_word,
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state

View File

@@ -109,7 +109,9 @@ def lexicon_to_fst_no_sil(
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs, disambig_token=disambig_token, disambig_word=disambig_word,
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
@@ -126,7 +128,10 @@ def lexicon_to_fst_no_sil(
def generate_otc_lexicon(
model_file: str, words: List[str], oov: str, otc_token: str,
model_file: str,
words: List[str],
oov: str,
otc_token: str,
) -> Tuple[Lexicon, Dict[str, int]]:
"""Generate a lexicon from a BPE model.
@@ -188,7 +193,10 @@ def get_args():
)
parser.add_argument(
"--otc-token", type=str, default="<star>", help="The OTC token in lexicon.",
"--otc-token",
type=str,
default="<star>",
help="The OTC token in lexicon.",
)
parser.add_argument(
@@ -256,7 +264,9 @@ def main():
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst_no_sil(
lexicon, token2id=token_sym_table, word2id=word_sym_table,
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
)
L_disambig = lexicon_to_fst_no_sil(

View File

@@ -38,7 +38,6 @@ class OtcTrainingGraphCompiler(object):
initial_self_loop_weight: float = 0.0,
bypass_weight_decay: float = 0.0,
self_loop_weight_decay: float = 0.0,
) -> None:
"""
Args:
@@ -93,7 +92,11 @@ class OtcTrainingGraphCompiler(object):
return max_token_id
def make_arc(
self, from_state: int, to_state: int, symbol: Union[str, int], weight: float,
self,
from_state: int,
to_state: int,
symbol: Union[str, int],
weight: float,
):
return f"{from_state} {to_state} {symbol} {weight}"
@@ -132,7 +135,7 @@ class OtcTrainingGraphCompiler(object):
Whether to add bypass arc to training graph for substitution
and insertion errors (wrong or extra words in the transcript).
allow_self_loop_arc:
Whether to add self-loop arc to training graph for deletion
Whether to add self-loop arc to training graph for deletion
errors (missing words in the transcript).
bypass_weight:
Weight associated with bypass arc.
@@ -140,7 +143,7 @@ class OtcTrainingGraphCompiler(object):
Weight associated with self-loop arc.
otc_granularity:
Use OTC token to model word or subword.
Return:
Return an FsaVec, which is the result of composing a
CTC topology with OTC FSAs constructed from the given texts.
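A toy illustration of the arc strings this compiler assembles with make_arc(): each transcript token gets an arc to the next state, plus an optional self-loop on the OTC token that absorbs deletions; the real graph also adds bypass arcs for substitutions/insertions and is composed with the CTC topology as described above. Token ids and weights below are made up:

def make_arc(from_state, to_state, symbol, weight):
    return f"{from_state} {to_state} {symbol} {weight}"

otc_token_id = 500                 # hypothetical id of the OTC token
transcript_ids = [12, 57]          # hypothetical ids for a two-word transcript
arcs, cur_state = [], 0
for token_id in transcript_ids:
    arcs.append(make_arc(cur_state, cur_state, otc_token_id, -0.5))  # self-loop (deletions)
    arcs.append(make_arc(cur_state, cur_state + 1, token_id, 0.0))   # transcript arc
    cur_state += 1
print("\n".join(arcs))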
@@ -161,7 +164,9 @@ class OtcTrainingGraphCompiler(object):
fsa_with_self_loop = k2.arc_sort(fsa_with_self_loop)
graph = k2.compose(
self.ctc_topo, fsa_with_self_loop, treat_epsilons_specially=False,
self.ctc_topo,
fsa_with_self_loop,
treat_epsilons_specially=False,
)
assert graph.requires_grad is False
@@ -201,7 +206,10 @@ class OtcTrainingGraphCompiler(object):
if allow_self_loop_arc:
self_loop_arc = self.make_arc(
cur_state, cur_state, otc_token_id, self_loop_weight,
cur_state,
cur_state,
otc_token_id,
self_loop_weight,
)
arcs.append(self_loop_arc)
@@ -225,7 +233,10 @@ class OtcTrainingGraphCompiler(object):
if allow_self_loop_arc:
self_loop_arc = self.make_arc(
cur_state, cur_state, otc_token_id, self_loop_weight,
cur_state,
cur_state,
otc_token_id,
self_loop_weight,
)
arcs.append(self_loop_arc)

View File

@@ -262,6 +262,7 @@ def get_texts(
else:
return aux_labels.tolist()
def encode_supervisions_otc(
supervisions: dict,
subsampling_factor: int,