apply black and isort

Author: Dongji Gao (2023-09-24 11:44:39 -04:00)
parent 1ea86de1da
commit 8178a0effc
13 changed files with 151 additions and 60 deletions
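Most of the diff is mechanical: black explodes call sites that exceed the line-length limit, or that already end in a trailing ("magic") comma, into one argument per line, and isort regroups and alphabetizes the imports (e.g. the OtcTrainingGraphCompiler import moves to its alphabetical position in decode.py and train.py). A few hunks change behavior as well: a new --otc-granularity option, alignment_interval lowered from 100 to 25, and extra alignment logging in the training script. A minimal sketch of the black rewrite pattern, using an argparse call taken from one of the hunks below (the parser object is created here only to make the snippet self-contained):

import argparse

parser = argparse.ArgumentParser()

# Before the commit: several keyword arguments packed onto one line,
# ending in a trailing comma.
# parser.add_argument(
#     "--otc-token", type=str, default="<star>", help="OTC token",
# )

# After running black: the trailing ("magic") comma makes black keep the
# call expanded, one argument per line.
parser.add_argument(
    "--otc-token",
    type=str,
    default="<star>",
    help="OTC token",
)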

View File

@@ -183,11 +183,13 @@ class LibriSpeechAsrDataModule:
             "--train-manifest",
             type=str,
             default="librispeech_cuts_train-clean-100.jsonl.gz",
-            help="Train manifest file."
+            help="Train manifest file.",
         )

     def train_dataloaders(
-        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None,
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
     ) -> DataLoader:
         """
         Args:
@@ -268,11 +270,14 @@ class LibriSpeechAsrDataModule:
         logging.info("About to create dev dataset")
         validate = K2SpeechRecognitionDataset(
-            cut_transforms=transforms, return_cuts=self.args.return_cuts,
+            cut_transforms=transforms,
+            return_cuts=self.args.return_cuts,
         )
         valid_sampler = DynamicBucketingSampler(
-            cuts_valid, max_duration=self.args.max_duration, shuffle=False,
+            cuts_valid,
+            max_duration=self.args.max_duration,
+            shuffle=False,
         )
         logging.info("About to create dev dataloader")
@@ -293,11 +298,16 @@ class LibriSpeechAsrDataModule:
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(
-            cuts, max_duration=self.args.max_duration, shuffle=False,
+            cuts,
+            max_duration=self.args.max_duration,
+            shuffle=False,
         )
         logging.debug("About to create test dataloader")
         test_dl = DataLoader(
-            test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers,
+            test,
+            batch_size=None,
+            sampler=sampler,
+            num_workers=self.args.num_workers,
         )
         return test_dl
@@ -311,9 +321,7 @@ class LibriSpeechAsrDataModule:
     @lru_cache()
     def train_clean_100_cuts(self) -> CutSet:
         logging.info("About to get train-clean-100 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / self.args.train_manifest
-        )
+        return load_manifest_lazy(self.args.manifest_dir / self.args.train_manifest)

     @lru_cache()
     def train_all_shuf_cuts(self) -> CutSet:

View File

@@ -92,7 +92,7 @@ class Conformer(Transformer):
         if self.subsampling_factor == 4:
             self.encoder_embed = Conv2dSubsampling(num_features, d_model)
         elif self.subsampling_factor == 2:
-            self.encoder_embed = Conv2dSubsampling2(num_features, d_model)
+            self.encoder_embed = Conv2dSubsampling2(num_features, d_model)
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)

View File

@@ -32,19 +32,16 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer

-from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-)
+from icefall.decode import get_lattice, one_best_decoding
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
+from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
 from icefall.utils import (
     AttributeDict,
     get_texts,
@@ -62,7 +59,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--otc-token", type=str, default="<star>", help="OTC token",
+        "--otc-token",
+        type=str,
+        default="<star>",
+        help="OTC token",
     )

     parser.add_argument(
@@ -137,11 +137,17 @@ def get_parser():
     )

     parser.add_argument(
-        "--exp-dir", type=str, default="conformer_ctc2/exp", help="The experiment dir",
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc2/exp",
+        help="The experiment dir",
     )

     parser.add_argument(
-        "--lang-dir", type=str, default="data/lang_bpe_200", help="The lang dir",
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_200",
+        help="The lang dir",
     )

     parser.add_argument(
@@ -345,7 +351,11 @@ def decode_one_batch(
         return {key: hyps}

     if params.method == "ctc-greedy-search":
-        hyps, _ = ctc_greedy_search(nnet_output, memory, memory_key_padding_mask,)
+        hyps, _ = ctc_greedy_search(
+            nnet_output,
+            memory,
+            memory_key_padding_mask,
+        )

         # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
         hyps = bpe_model.decode(hyps)
@@ -557,7 +567,11 @@ def main():
     if params.method == "ctc-decoding" or params.method == "ctc-greedy-search":
         HLG = None
-        H = k2.ctc_topo(max_token=max_token_id, modified=False, device=device,)
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
         bpe_model = spm.SentencePieceProcessor()
         bpe_model.load(str(params.lang_dir / "bpe.model"))
     else:

View File

@@ -120,6 +120,7 @@ class Conv2dSubsampling(torch.nn.Module):
         x = self.out_balancer(x)
         return x

+
 class Conv2dSubsampling2(torch.nn.Module):
     """Convolutional 2D subsampling (to 1/2 length).

View File

@@ -66,24 +66,24 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter

 from icefall import diagnostics
-from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
     save_checkpoint_with_global_batch_idx,
     update_averaged_model,
 )
+from icefall.decode import one_best_decoding
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     encode_supervisions_otc,
+    get_texts,
     setup_logger,
     str2bool,
-    get_texts,
 )
-from icefall.decode import one_best_decoding

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -94,7 +94,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
     )

     parser.add_argument(
@@ -112,7 +115,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--num-epochs", type=int, default=20, help="Number of epochs to train.",
+        "--num-epochs",
+        type=int,
+        default=20,
+        help="Number of epochs to train.",
     )

     parser.add_argument(
@@ -255,7 +261,18 @@ def get_parser():
     )

     parser.add_argument(
-        "--otc-token", type=str, default="_<star>", help="OTC token",
+        "--otc-token",
+        type=str,
+        default="_<star>",
+        help="OTC token",
+    )
+
+    parser.add_argument(
+        "--otc-granularity",
+        type=str,
+        choices=["word", "subword"],
+        default="word",
+        help="OTC granularity",
     )

     parser.add_argument(
@@ -374,7 +391,7 @@ def get_params() -> AttributeDict:
             "log_interval": 1,
             "reset_interval": 200,
             "valid_interval": 800,  # For the 100h subset, use 800
-            "alignment_interval": 100,
+            "alignment_interval": 25,
             # parameters for conformer
             "feature_dim": 768,
             "subsampling_factor": 2,
@@ -585,9 +602,14 @@ def compute_loss(
             allow_self_loop_arc=params.allow_self_loop_arc,
             bypass_weight=bypass_weight,
             self_loop_weight=self_loop_weight,
+            otc_granularity=params.otc_granularity,
         )

-        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments, allow_truncate=3,)
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output,
+            supervision_segments,
+            allow_truncate=3,
+        )

         otc_loss = k2.ctc_loss(
             decoding_graph=decoding_graph,
@@ -627,18 +649,22 @@ def compute_loss(
                 utt_id = utt_ids[index]

                 lattice = k2.intersect_dense(
-                    decoding_graph, dense_fsa_vec, params.beam_size,
+                    decoding_graph,
+                    dense_fsa_vec,
+                    params.beam_size,
                 )
                 best_path = one_best_decoding(
-                    lattice=lattice, use_double_scores=params.use_double_scores,
+                    lattice=lattice,
+                    use_double_scores=params.use_double_scores,
                 )

                 hyp_ids = get_texts(best_path)[index]
                 hyp_text_list = [graph_compiler.token_table[i] for i in hyp_ids]
-                hyp_text = " ".join(hyp_text_list)
+                hyp_text = "".join(hyp_text_list).replace("▁", " ")

                 logging.info(f"[utterance id]: {utt_id}")
                 logging.info(f"[verbatim text]: {verbatim_text}")
                 logging.info(f"[best alignment]: {hyp_text}")
+                logging.info(bypass_weight)

     return loss, info
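The hunk above also changes how the logged alignment is detokenized: the subword pieces from graph_compiler.token_table are now concatenated and the SentencePiece word-boundary marker (shown as "▁" here; the exact character is garbled in this rendering of the diff) is replaced by a space. A small illustrative sketch of that convention, with an invented token list:

# Illustrative only: the token strings below are made up for the example.
tokens = ["▁HE", "LLO", "▁WORLD"]
text = "".join(tokens).replace("▁", " ").strip()
print(text)  # -> "HELLO WORLD"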
@@ -770,7 +796,9 @@ def train_one_epoch(
             and params.batch_idx_train % params.average_period == 0
         ):
             update_averaged_model(
-                params=params, model_cur=model, model_avg=model_avg,
+                params=params,
+                model_cur=model,
+                model_avg=model_avg,
             )

         if (
@@ -790,7 +818,9 @@ def train_one_epoch(
                 rank=rank,
             )
             remove_checkpoints(
-                out_dir=params.exp_dir, topk=params.keep_last_k, rank=rank,
+                out_dir=params.exp_dir,
+                topk=params.keep_last_k,
+                rank=rank,
             )

         if batch_idx % params.log_interval == 0:

View File

@@ -909,7 +909,9 @@ class Noam(object):
 def encoder_padding_mask(
-    max_len: int, subsampling_factor: Optional[int] = 4, supervisions: Optional[Supervisions] = None
+    max_len: int,
+    subsampling_factor: Optional[int] = 4,
+    supervisions: Optional[Supervisions] = None,
 ) -> Optional[torch.Tensor]:
     """Make mask tensor containing indexes of padded part.

View File

@@ -29,7 +29,7 @@ import os
 from pathlib import Path

 import torch
-from lhotse import CutSet, S3PRLSSL, S3PRLSSLConfig, NumpyFilesWriter
+from lhotse import S3PRLSSL, CutSet, NumpyFilesWriter, S3PRLSSLConfig
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor
@@ -71,9 +71,7 @@ def compute_ssl_librispeech():
         dataset_parts,
     )
-    extractor = S3PRLSSL(
-        S3PRLSSLConfig(ssl_model="wav2vec2", device="cuda")
-    )
+    extractor = S3PRLSSL(S3PRLSSLConfig(ssl_model="wav2vec2", device="cuda"))

     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -95,9 +93,7 @@ def compute_ssl_librispeech():
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)

View File

@@ -18,7 +18,9 @@ def get_args():
     )

     parser.add_argument(
-        "--otc-token", type=str, help="OTC token to be added to words.txt",
+        "--otc-token",
+        type=str,
+        help="OTC token to be added to words.txt",
     )

     return parser.parse_args()

View File

@@ -7,10 +7,11 @@ import random
 from pathlib import Path
 from typing import List

-from icefall.utils import str2bool
 from lhotse import CutSet, load_manifest
 from lhotse.cut.base import Cut

+from icefall.utils import str2bool

 def get_args():
     parser = argparse.ArgumentParser(
@@ -23,23 +24,36 @@ def get_args():
     )

     parser.add_argument(
-        "--words-file", type=str, help="words.txt file",
+        "--words-file",
+        type=str,
+        help="words.txt file",
     )

     parser.add_argument(
-        "--otc-token", type=str, help="OTC token in words.txt",
+        "--otc-token",
+        type=str,
+        help="OTC token in words.txt",
     )

     parser.add_argument(
-        "--sub-error-rate", type=float, default=0.0, help="Substitution error rate",
+        "--sub-error-rate",
+        type=float,
+        default=0.0,
+        help="Substitution error rate",
     )

     parser.add_argument(
-        "--ins-error-rate", type=float, default=0.0, help="Insertion error rate",
+        "--ins-error-rate",
+        type=float,
+        default=0.0,
+        help="Insertion error rate",
     )

     parser.add_argument(
-        "--del-error-rate", type=float, default=0.0, help="Deletion error rate",
+        "--del-error-rate",
+        type=float,
+        default=0.0,
+        help="Deletion error rate",
     )

     parser.add_argument(

View File

@@ -324,7 +324,9 @@ def lexicon_to_fst(
         disambig_token = token2id["#0"]
         disambig_word = word2id["#0"]
         arcs = add_self_loops(
-            arcs, disambig_token=disambig_token, disambig_word=disambig_word,
+            arcs,
+            disambig_token=disambig_token,
+            disambig_word=disambig_word,
         )

     final_state = next_state

View File

@@ -109,7 +109,9 @@ def lexicon_to_fst_no_sil(
         disambig_token = token2id["#0"]
         disambig_word = word2id["#0"]
         arcs = add_self_loops(
-            arcs, disambig_token=disambig_token, disambig_word=disambig_word,
+            arcs,
+            disambig_token=disambig_token,
+            disambig_word=disambig_word,
         )

     final_state = next_state
@@ -126,7 +128,10 @@ def lexicon_to_fst_no_sil(
 def generate_otc_lexicon(
-    model_file: str, words: List[str], oov: str, otc_token: str,
+    model_file: str,
+    words: List[str],
+    oov: str,
+    otc_token: str,
 ) -> Tuple[Lexicon, Dict[str, int]]:
     """Generate a lexicon from a BPE model.
@@ -188,7 +193,10 @@ def get_args():
     )

     parser.add_argument(
-        "--otc-token", type=str, default="<star>", help="The OTC token in lexicon.",
+        "--otc-token",
+        type=str,
+        default="<star>",
+        help="The OTC token in lexicon.",
     )

     parser.add_argument(
@@ -256,7 +264,9 @@ def main():
     write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)

     L = lexicon_to_fst_no_sil(
-        lexicon, token2id=token_sym_table, word2id=word_sym_table,
+        lexicon,
+        token2id=token_sym_table,
+        word2id=word_sym_table,
     )

     L_disambig = lexicon_to_fst_no_sil(

View File

@@ -38,7 +38,6 @@ class OtcTrainingGraphCompiler(object):
         initial_self_loop_weight: float = 0.0,
         bypass_weight_decay: float = 0.0,
         self_loop_weight_decay: float = 0.0,
     ) -> None:
-
         """
         Args:
@@ -93,7 +92,11 @@ class OtcTrainingGraphCompiler(object):
         return max_token_id

     def make_arc(
-        self, from_state: int, to_state: int, symbol: Union[str, int], weight: float,
+        self,
+        from_state: int,
+        to_state: int,
+        symbol: Union[str, int],
+        weight: float,
     ):
         return f"{from_state} {to_state} {symbol} {weight}"
@@ -132,7 +135,7 @@ class OtcTrainingGraphCompiler(object):
            Whether to add bypass arc to training graph for substitution
            and insertion errors (wrong or extra words in the transcript).
          allow_self_loop_arc:
-            Whether to add self-loop arc to training graph for deletion
+            Whether to add self-loop arc to training graph for deletion
            errors (missing words in the transcript).
          bypass_weight:
            Weight associated with bypass arc.
@@ -140,7 +143,7 @@ class OtcTrainingGraphCompiler(object):
            Weight associated with self-loop arc.
          otc_granularity:
            Use OTC token to model word or subword.
-        Return:
+        Return:
          Return an FsaVec, which is the result of composing a
          CTC topology with OTC FSAs constructed from the given texts.
@@ -161,7 +164,9 @@ class OtcTrainingGraphCompiler(object):
        fsa_with_self_loop = k2.arc_sort(fsa_with_self_loop)

        graph = k2.compose(
-            self.ctc_topo, fsa_with_self_loop, treat_epsilons_specially=False,
+            self.ctc_topo,
+            fsa_with_self_loop,
+            treat_epsilons_specially=False,
        )
        assert graph.requires_grad is False
@@ -201,7 +206,10 @@ class OtcTrainingGraphCompiler(object):
            if allow_self_loop_arc:
                self_loop_arc = self.make_arc(
-                    cur_state, cur_state, otc_token_id, self_loop_weight,
+                    cur_state,
+                    cur_state,
+                    otc_token_id,
+                    self_loop_weight,
                )
                arcs.append(self_loop_arc)
@@ -225,7 +233,10 @@ class OtcTrainingGraphCompiler(object):
            if allow_self_loop_arc:
                self_loop_arc = self.make_arc(
-                    cur_state, cur_state, otc_token_id, self_loop_weight,
+                    cur_state,
+                    cur_state,
+                    otc_token_id,
+                    self_loop_weight,
                )
                arcs.append(self_loop_arc)

View File

@@ -262,6 +262,7 @@ def get_texts(
     else:
         return aux_labels.tolist()

+
 def encode_supervisions_otc(
     supervisions: dict,
     subsampling_factor: int,