mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
apply black and isort
This commit is contained in:
parent
1ea86de1da
commit
8178a0effc
@ -183,11 +183,13 @@ class LibriSpeechAsrDataModule:
|
||||
"--train-manifest",
|
||||
type=str,
|
||||
default="librispeech_cuts_train-clean-100.jsonl.gz",
|
||||
help="Train manifest file."
|
||||
help="Train manifest file.",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
@ -268,11 +270,14 @@ class LibriSpeechAsrDataModule:
|
||||
logging.info("About to create dev dataset")
|
||||
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms, return_cuts=self.args.return_cuts,
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid, max_duration=self.args.max_duration, shuffle=False,
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
logging.info("About to create dev dataloader")
|
||||
@ -293,11 +298,16 @@ class LibriSpeechAsrDataModule:
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts, max_duration=self.args.max_duration, shuffle=False,
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers,
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@ -311,9 +321,7 @@ class LibriSpeechAsrDataModule:
|
||||
@lru_cache()
|
||||
def train_clean_100_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-100 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / self.args.train_manifest
|
||||
)
|
||||
return load_manifest_lazy(self.args.manifest_dir / self.args.train_manifest)
|
||||
|
||||
@lru_cache()
|
||||
def train_all_shuf_cuts(self) -> CutSet:
|
||||
|
@ -92,7 +92,7 @@ class Conformer(Transformer):
|
||||
if self.subsampling_factor == 4:
|
||||
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
|
||||
elif self.subsampling_factor == 2:
|
||||
self.encoder_embed = Conv2dSubsampling2(num_features, d_model)
|
||||
self.encoder_embed = Conv2dSubsampling2(num_features, d_model)
|
||||
|
||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||
|
||||
|
@ -32,19 +32,16 @@ import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
|
||||
from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
one_best_decoding,
|
||||
)
|
||||
from icefall.decode import get_lattice, one_best_decoding
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
@ -62,7 +59,10 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--otc-token", type=str, default="<star>", help="OTC token",
|
||||
"--otc-token",
|
||||
type=str,
|
||||
default="<star>",
|
||||
help="OTC token",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -137,11 +137,17 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir", type=str, default="conformer_ctc2/exp", help="The experiment dir",
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="conformer_ctc2/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir", type=str, default="data/lang_bpe_200", help="The lang dir",
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_bpe_200",
|
||||
help="The lang dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -345,7 +351,11 @@ def decode_one_batch(
|
||||
return {key: hyps}
|
||||
|
||||
if params.method == "ctc-greedy-search":
|
||||
hyps, _ = ctc_greedy_search(nnet_output, memory, memory_key_padding_mask,)
|
||||
hyps, _ = ctc_greedy_search(
|
||||
nnet_output,
|
||||
memory,
|
||||
memory_key_padding_mask,
|
||||
)
|
||||
|
||||
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||
hyps = bpe_model.decode(hyps)
|
||||
@ -557,7 +567,11 @@ def main():
|
||||
|
||||
if params.method == "ctc-decoding" or params.method == "ctc-greedy-search":
|
||||
HLG = None
|
||||
H = k2.ctc_topo(max_token=max_token_id, modified=False, device=device,)
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
modified=False,
|
||||
device=device,
|
||||
)
|
||||
bpe_model = spm.SentencePieceProcessor()
|
||||
bpe_model.load(str(params.lang_dir / "bpe.model"))
|
||||
else:
|
||||
|
@ -120,6 +120,7 @@ class Conv2dSubsampling(torch.nn.Module):
|
||||
x = self.out_balancer(x)
|
||||
return x
|
||||
|
||||
|
||||
class Conv2dSubsampling2(torch.nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/2 length).
|
||||
|
||||
|
@ -66,24 +66,24 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from icefall import diagnostics
|
||||
from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.checkpoint import (
|
||||
save_checkpoint_with_global_batch_idx,
|
||||
update_averaged_model,
|
||||
)
|
||||
from icefall.decode import one_best_decoding
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
encode_supervisions_otc,
|
||||
get_texts,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
get_texts,
|
||||
)
|
||||
from icefall.decode import one_best_decoding
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
|
||||
@ -94,7 +94,10 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
|
||||
"--world-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of GPUs for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -112,7 +115,10 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-epochs", type=int, default=20, help="Number of epochs to train.",
|
||||
"--num-epochs",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Number of epochs to train.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -255,7 +261,18 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--otc-token", type=str, default="_<star>", help="OTC token",
|
||||
"--otc-token",
|
||||
type=str,
|
||||
default="_<star>",
|
||||
help="OTC token",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--otc-granularity",
|
||||
type=str,
|
||||
choices=["word", "subword"],
|
||||
default="word",
|
||||
help="OTC granularity",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -374,7 +391,7 @@ def get_params() -> AttributeDict:
|
||||
"log_interval": 1,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 800, # For the 100h subset, use 800
|
||||
"alignment_interval": 100,
|
||||
"alignment_interval": 25,
|
||||
# parameters for conformer
|
||||
"feature_dim": 768,
|
||||
"subsampling_factor": 2,
|
||||
@ -585,9 +602,14 @@ def compute_loss(
|
||||
allow_self_loop_arc=params.allow_self_loop_arc,
|
||||
bypass_weight=bypass_weight,
|
||||
self_loop_weight=self_loop_weight,
|
||||
otc_granularity=params.otc_granularity,
|
||||
)
|
||||
|
||||
dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments, allow_truncate=3,)
|
||||
dense_fsa_vec = k2.DenseFsaVec(
|
||||
nnet_output,
|
||||
supervision_segments,
|
||||
allow_truncate=3,
|
||||
)
|
||||
|
||||
otc_loss = k2.ctc_loss(
|
||||
decoding_graph=decoding_graph,
|
||||
@ -627,18 +649,22 @@ def compute_loss(
|
||||
utt_id = utt_ids[index]
|
||||
|
||||
lattice = k2.intersect_dense(
|
||||
decoding_graph, dense_fsa_vec, params.beam_size,
|
||||
decoding_graph,
|
||||
dense_fsa_vec,
|
||||
params.beam_size,
|
||||
)
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores,
|
||||
lattice=lattice,
|
||||
use_double_scores=params.use_double_scores,
|
||||
)
|
||||
hyp_ids = get_texts(best_path)[index]
|
||||
hyp_text_list = [graph_compiler.token_table[i] for i in hyp_ids]
|
||||
hyp_text = " ".join(hyp_text_list)
|
||||
hyp_text = "".join(hyp_text_list).replace("▁", " ")
|
||||
|
||||
logging.info(f"[utterance id]: {utt_id}")
|
||||
logging.info(f"[verbatim text]: {verbatim_text}")
|
||||
logging.info(f"[best alignment]: {hyp_text}")
|
||||
logging.info(bypass_weight)
|
||||
|
||||
return loss, info
|
||||
|
||||
@ -770,7 +796,9 @@ def train_one_epoch(
|
||||
and params.batch_idx_train % params.average_period == 0
|
||||
):
|
||||
update_averaged_model(
|
||||
params=params, model_cur=model, model_avg=model_avg,
|
||||
params=params,
|
||||
model_cur=model,
|
||||
model_avg=model_avg,
|
||||
)
|
||||
|
||||
if (
|
||||
@ -790,7 +818,9 @@ def train_one_epoch(
|
||||
rank=rank,
|
||||
)
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir, topk=params.keep_last_k, rank=rank,
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
|
@ -909,7 +909,9 @@ class Noam(object):
|
||||
|
||||
|
||||
def encoder_padding_mask(
|
||||
max_len: int, subsampling_factor: Optional[int] = 4, supervisions: Optional[Supervisions] = None
|
||||
max_len: int,
|
||||
subsampling_factor: Optional[int] = 4,
|
||||
supervisions: Optional[Supervisions] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""Make mask tensor containing indexes of padded part.
|
||||
|
||||
|
@ -29,7 +29,7 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, S3PRLSSL, S3PRLSSLConfig, NumpyFilesWriter
|
||||
from lhotse import S3PRLSSL, CutSet, NumpyFilesWriter, S3PRLSSLConfig
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
@ -71,9 +71,7 @@ def compute_ssl_librispeech():
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
extractor = S3PRLSSL(
|
||||
S3PRLSSLConfig(ssl_model="wav2vec2", device="cuda")
|
||||
)
|
||||
extractor = S3PRLSSL(S3PRLSSLConfig(ssl_model="wav2vec2", device="cuda"))
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
for partition, m in manifests.items():
|
||||
@ -95,9 +93,7 @@ def compute_ssl_librispeech():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
|
@ -18,7 +18,9 @@ def get_args():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--otc-token", type=str, help="OTC token to be added to words.txt",
|
||||
"--otc-token",
|
||||
type=str,
|
||||
help="OTC token to be added to words.txt",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
@ -7,10 +7,11 @@ import random
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from icefall.utils import str2bool
|
||||
from lhotse import CutSet, load_manifest
|
||||
from lhotse.cut.base import Cut
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -23,23 +24,36 @@ def get_args():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--words-file", type=str, help="words.txt file",
|
||||
"--words-file",
|
||||
type=str,
|
||||
help="words.txt file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--otc-token", type=str, help="OTC token in words.txt",
|
||||
"--otc-token",
|
||||
type=str,
|
||||
help="OTC token in words.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sub-error-rate", type=float, default=0.0, help="Substitution error rate",
|
||||
"--sub-error-rate",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Substitution error rate",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ins-error-rate", type=float, default=0.0, help="Insertion error rate",
|
||||
"--ins-error-rate",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Insertion error rate",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--del-error-rate", type=float, default=0.0, help="Deletion error rate",
|
||||
"--del-error-rate",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Deletion error rate",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
@ -324,7 +324,9 @@ def lexicon_to_fst(
|
||||
disambig_token = token2id["#0"]
|
||||
disambig_word = word2id["#0"]
|
||||
arcs = add_self_loops(
|
||||
arcs, disambig_token=disambig_token, disambig_word=disambig_word,
|
||||
arcs,
|
||||
disambig_token=disambig_token,
|
||||
disambig_word=disambig_word,
|
||||
)
|
||||
|
||||
final_state = next_state
|
||||
|
@ -109,7 +109,9 @@ def lexicon_to_fst_no_sil(
|
||||
disambig_token = token2id["#0"]
|
||||
disambig_word = word2id["#0"]
|
||||
arcs = add_self_loops(
|
||||
arcs, disambig_token=disambig_token, disambig_word=disambig_word,
|
||||
arcs,
|
||||
disambig_token=disambig_token,
|
||||
disambig_word=disambig_word,
|
||||
)
|
||||
|
||||
final_state = next_state
|
||||
@ -126,7 +128,10 @@ def lexicon_to_fst_no_sil(
|
||||
|
||||
|
||||
def generate_otc_lexicon(
|
||||
model_file: str, words: List[str], oov: str, otc_token: str,
|
||||
model_file: str,
|
||||
words: List[str],
|
||||
oov: str,
|
||||
otc_token: str,
|
||||
) -> Tuple[Lexicon, Dict[str, int]]:
|
||||
"""Generate a lexicon from a BPE model.
|
||||
|
||||
@ -188,7 +193,10 @@ def get_args():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--otc-token", type=str, default="<star>", help="The OTC token in lexicon.",
|
||||
"--otc-token",
|
||||
type=str,
|
||||
default="<star>",
|
||||
help="The OTC token in lexicon.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -256,7 +264,9 @@ def main():
|
||||
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
|
||||
|
||||
L = lexicon_to_fst_no_sil(
|
||||
lexicon, token2id=token_sym_table, word2id=word_sym_table,
|
||||
lexicon,
|
||||
token2id=token_sym_table,
|
||||
word2id=word_sym_table,
|
||||
)
|
||||
|
||||
L_disambig = lexicon_to_fst_no_sil(
|
||||
|
@ -38,7 +38,6 @@ class OtcTrainingGraphCompiler(object):
|
||||
initial_self_loop_weight: float = 0.0,
|
||||
bypass_weight_decay: float = 0.0,
|
||||
self_loop_weight_decay: float = 0.0,
|
||||
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
@ -93,7 +92,11 @@ class OtcTrainingGraphCompiler(object):
|
||||
return max_token_id
|
||||
|
||||
def make_arc(
|
||||
self, from_state: int, to_state: int, symbol: Union[str, int], weight: float,
|
||||
self,
|
||||
from_state: int,
|
||||
to_state: int,
|
||||
symbol: Union[str, int],
|
||||
weight: float,
|
||||
):
|
||||
return f"{from_state} {to_state} {symbol} {weight}"
|
||||
|
||||
@ -132,7 +135,7 @@ class OtcTrainingGraphCompiler(object):
|
||||
Whether to add bypass arc to training graph for substitution
|
||||
and insertion errors (wrong or extra words in the transcript).
|
||||
allow_self_loop_arc:
|
||||
Whether to add self-loop arc to training graph for deletion
|
||||
Whether to add self-loop arc to training graph for deletion
|
||||
errors (missing words in the transcript).
|
||||
bypass_weight:
|
||||
Weight associated with bypass arc.
|
||||
@ -140,7 +143,7 @@ class OtcTrainingGraphCompiler(object):
|
||||
Weight associated with self-loop arc.
|
||||
otc_granularity:
|
||||
Use OTC token to model word or subword.
|
||||
|
||||
|
||||
Return:
|
||||
Return an FsaVec, which is the result of composing a
|
||||
CTC topology with OTC FSAs constructed from the given texts.
|
||||
@ -161,7 +164,9 @@ class OtcTrainingGraphCompiler(object):
|
||||
fsa_with_self_loop = k2.arc_sort(fsa_with_self_loop)
|
||||
|
||||
graph = k2.compose(
|
||||
self.ctc_topo, fsa_with_self_loop, treat_epsilons_specially=False,
|
||||
self.ctc_topo,
|
||||
fsa_with_self_loop,
|
||||
treat_epsilons_specially=False,
|
||||
)
|
||||
assert graph.requires_grad is False
|
||||
|
||||
@ -201,7 +206,10 @@ class OtcTrainingGraphCompiler(object):
|
||||
|
||||
if allow_self_loop_arc:
|
||||
self_loop_arc = self.make_arc(
|
||||
cur_state, cur_state, otc_token_id, self_loop_weight,
|
||||
cur_state,
|
||||
cur_state,
|
||||
otc_token_id,
|
||||
self_loop_weight,
|
||||
)
|
||||
arcs.append(self_loop_arc)
|
||||
|
||||
@ -225,7 +233,10 @@ class OtcTrainingGraphCompiler(object):
|
||||
|
||||
if allow_self_loop_arc:
|
||||
self_loop_arc = self.make_arc(
|
||||
cur_state, cur_state, otc_token_id, self_loop_weight,
|
||||
cur_state,
|
||||
cur_state,
|
||||
otc_token_id,
|
||||
self_loop_weight,
|
||||
)
|
||||
arcs.append(self_loop_arc)
|
||||
|
||||
|
@ -262,6 +262,7 @@ def get_texts(
|
||||
else:
|
||||
return aux_labels.tolist()
|
||||
|
||||
|
||||
def encode_supervisions_otc(
|
||||
supervisions: dict,
|
||||
subsampling_factor: int,
|
||||
|
Loading…
x
Reference in New Issue
Block a user