mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 00:24:19 +00:00
apply black and isort
This commit is contained in:
parent
1ea86de1da
commit
8178a0effc
@ -183,11 +183,13 @@ class LibriSpeechAsrDataModule:
|
|||||||
"--train-manifest",
|
"--train-manifest",
|
||||||
type=str,
|
type=str,
|
||||||
default="librispeech_cuts_train-clean-100.jsonl.gz",
|
default="librispeech_cuts_train-clean-100.jsonl.gz",
|
||||||
help="Train manifest file."
|
help="Train manifest file.",
|
||||||
)
|
)
|
||||||
|
|
||||||
def train_dataloaders(
|
def train_dataloaders(
|
||||||
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None,
|
self,
|
||||||
|
cuts_train: CutSet,
|
||||||
|
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||||
) -> DataLoader:
|
) -> DataLoader:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -268,11 +270,14 @@ class LibriSpeechAsrDataModule:
|
|||||||
logging.info("About to create dev dataset")
|
logging.info("About to create dev dataset")
|
||||||
|
|
||||||
validate = K2SpeechRecognitionDataset(
|
validate = K2SpeechRecognitionDataset(
|
||||||
cut_transforms=transforms, return_cuts=self.args.return_cuts,
|
cut_transforms=transforms,
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_sampler = DynamicBucketingSampler(
|
valid_sampler = DynamicBucketingSampler(
|
||||||
cuts_valid, max_duration=self.args.max_duration, shuffle=False,
|
cuts_valid,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("About to create dev dataloader")
|
logging.info("About to create dev dataloader")
|
||||||
@ -293,11 +298,16 @@ class LibriSpeechAsrDataModule:
|
|||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
sampler = DynamicBucketingSampler(
|
sampler = DynamicBucketingSampler(
|
||||||
cuts, max_duration=self.args.max_duration, shuffle=False,
|
cuts,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=False,
|
||||||
)
|
)
|
||||||
logging.debug("About to create test dataloader")
|
logging.debug("About to create test dataloader")
|
||||||
test_dl = DataLoader(
|
test_dl = DataLoader(
|
||||||
test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers,
|
test,
|
||||||
|
batch_size=None,
|
||||||
|
sampler=sampler,
|
||||||
|
num_workers=self.args.num_workers,
|
||||||
)
|
)
|
||||||
return test_dl
|
return test_dl
|
||||||
|
|
||||||
@ -311,9 +321,7 @@ class LibriSpeechAsrDataModule:
|
|||||||
@lru_cache()
|
@lru_cache()
|
||||||
def train_clean_100_cuts(self) -> CutSet:
|
def train_clean_100_cuts(self) -> CutSet:
|
||||||
logging.info("About to get train-clean-100 cuts")
|
logging.info("About to get train-clean-100 cuts")
|
||||||
return load_manifest_lazy(
|
return load_manifest_lazy(self.args.manifest_dir / self.args.train_manifest)
|
||||||
self.args.manifest_dir / self.args.train_manifest
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def train_all_shuf_cuts(self) -> CutSet:
|
def train_all_shuf_cuts(self) -> CutSet:
|
||||||
|
@ -92,7 +92,7 @@ class Conformer(Transformer):
|
|||||||
if self.subsampling_factor == 4:
|
if self.subsampling_factor == 4:
|
||||||
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
|
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
|
||||||
elif self.subsampling_factor == 2:
|
elif self.subsampling_factor == 2:
|
||||||
self.encoder_embed = Conv2dSubsampling2(num_features, d_model)
|
self.encoder_embed = Conv2dSubsampling2(num_features, d_model)
|
||||||
|
|
||||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||||
|
|
||||||
|
@ -32,19 +32,16 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
|
|
||||||
from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.decode import (
|
from icefall.decode import get_lattice, one_best_decoding
|
||||||
get_lattice,
|
|
||||||
one_best_decoding,
|
|
||||||
)
|
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
get_texts,
|
get_texts,
|
||||||
@ -62,7 +59,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--otc-token", type=str, default="<star>", help="OTC token",
|
"--otc-token",
|
||||||
|
type=str,
|
||||||
|
default="<star>",
|
||||||
|
help="OTC token",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -137,11 +137,17 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir", type=str, default="conformer_ctc2/exp", help="The experiment dir",
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="conformer_ctc2/exp",
|
||||||
|
help="The experiment dir",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lang-dir", type=str, default="data/lang_bpe_200", help="The lang dir",
|
"--lang-dir",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_200",
|
||||||
|
help="The lang dir",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -345,7 +351,11 @@ def decode_one_batch(
|
|||||||
return {key: hyps}
|
return {key: hyps}
|
||||||
|
|
||||||
if params.method == "ctc-greedy-search":
|
if params.method == "ctc-greedy-search":
|
||||||
hyps, _ = ctc_greedy_search(nnet_output, memory, memory_key_padding_mask,)
|
hyps, _ = ctc_greedy_search(
|
||||||
|
nnet_output,
|
||||||
|
memory,
|
||||||
|
memory_key_padding_mask,
|
||||||
|
)
|
||||||
|
|
||||||
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||||
hyps = bpe_model.decode(hyps)
|
hyps = bpe_model.decode(hyps)
|
||||||
@ -557,7 +567,11 @@ def main():
|
|||||||
|
|
||||||
if params.method == "ctc-decoding" or params.method == "ctc-greedy-search":
|
if params.method == "ctc-decoding" or params.method == "ctc-greedy-search":
|
||||||
HLG = None
|
HLG = None
|
||||||
H = k2.ctc_topo(max_token=max_token_id, modified=False, device=device,)
|
H = k2.ctc_topo(
|
||||||
|
max_token=max_token_id,
|
||||||
|
modified=False,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
bpe_model = spm.SentencePieceProcessor()
|
bpe_model = spm.SentencePieceProcessor()
|
||||||
bpe_model.load(str(params.lang_dir / "bpe.model"))
|
bpe_model.load(str(params.lang_dir / "bpe.model"))
|
||||||
else:
|
else:
|
||||||
|
@ -120,6 +120,7 @@ class Conv2dSubsampling(torch.nn.Module):
|
|||||||
x = self.out_balancer(x)
|
x = self.out_balancer(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Conv2dSubsampling2(torch.nn.Module):
|
class Conv2dSubsampling2(torch.nn.Module):
|
||||||
"""Convolutional 2D subsampling (to 1/2 length).
|
"""Convolutional 2D subsampling (to 1/2 length).
|
||||||
|
|
||||||
|
@ -66,24 +66,24 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from icefall import diagnostics
|
from icefall import diagnostics
|
||||||
from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
|
|
||||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
save_checkpoint_with_global_batch_idx,
|
save_checkpoint_with_global_batch_idx,
|
||||||
update_averaged_model,
|
update_averaged_model,
|
||||||
)
|
)
|
||||||
|
from icefall.decode import one_best_decoding
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
MetricsTracker,
|
MetricsTracker,
|
||||||
encode_supervisions_otc,
|
encode_supervisions_otc,
|
||||||
|
get_texts,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
str2bool,
|
str2bool,
|
||||||
get_texts,
|
|
||||||
)
|
)
|
||||||
from icefall.decode import one_best_decoding
|
|
||||||
|
|
||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
|
|
||||||
@ -94,7 +94,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
|
"--world-size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of GPUs for DDP training.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -112,7 +115,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-epochs", type=int, default=20, help="Number of epochs to train.",
|
"--num-epochs",
|
||||||
|
type=int,
|
||||||
|
default=20,
|
||||||
|
help="Number of epochs to train.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -255,7 +261,18 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--otc-token", type=str, default="_<star>", help="OTC token",
|
"--otc-token",
|
||||||
|
type=str,
|
||||||
|
default="_<star>",
|
||||||
|
help="OTC token",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--otc-granularity",
|
||||||
|
type=str,
|
||||||
|
choices=["word", "subword"],
|
||||||
|
default="word",
|
||||||
|
help="OTC granularity",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -374,7 +391,7 @@ def get_params() -> AttributeDict:
|
|||||||
"log_interval": 1,
|
"log_interval": 1,
|
||||||
"reset_interval": 200,
|
"reset_interval": 200,
|
||||||
"valid_interval": 800, # For the 100h subset, use 800
|
"valid_interval": 800, # For the 100h subset, use 800
|
||||||
"alignment_interval": 100,
|
"alignment_interval": 25,
|
||||||
# parameters for conformer
|
# parameters for conformer
|
||||||
"feature_dim": 768,
|
"feature_dim": 768,
|
||||||
"subsampling_factor": 2,
|
"subsampling_factor": 2,
|
||||||
@ -585,9 +602,14 @@ def compute_loss(
|
|||||||
allow_self_loop_arc=params.allow_self_loop_arc,
|
allow_self_loop_arc=params.allow_self_loop_arc,
|
||||||
bypass_weight=bypass_weight,
|
bypass_weight=bypass_weight,
|
||||||
self_loop_weight=self_loop_weight,
|
self_loop_weight=self_loop_weight,
|
||||||
|
otc_granularity=params.otc_granularity,
|
||||||
)
|
)
|
||||||
|
|
||||||
dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments, allow_truncate=3,)
|
dense_fsa_vec = k2.DenseFsaVec(
|
||||||
|
nnet_output,
|
||||||
|
supervision_segments,
|
||||||
|
allow_truncate=3,
|
||||||
|
)
|
||||||
|
|
||||||
otc_loss = k2.ctc_loss(
|
otc_loss = k2.ctc_loss(
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
@ -627,18 +649,22 @@ def compute_loss(
|
|||||||
utt_id = utt_ids[index]
|
utt_id = utt_ids[index]
|
||||||
|
|
||||||
lattice = k2.intersect_dense(
|
lattice = k2.intersect_dense(
|
||||||
decoding_graph, dense_fsa_vec, params.beam_size,
|
decoding_graph,
|
||||||
|
dense_fsa_vec,
|
||||||
|
params.beam_size,
|
||||||
)
|
)
|
||||||
best_path = one_best_decoding(
|
best_path = one_best_decoding(
|
||||||
lattice=lattice, use_double_scores=params.use_double_scores,
|
lattice=lattice,
|
||||||
|
use_double_scores=params.use_double_scores,
|
||||||
)
|
)
|
||||||
hyp_ids = get_texts(best_path)[index]
|
hyp_ids = get_texts(best_path)[index]
|
||||||
hyp_text_list = [graph_compiler.token_table[i] for i in hyp_ids]
|
hyp_text_list = [graph_compiler.token_table[i] for i in hyp_ids]
|
||||||
hyp_text = " ".join(hyp_text_list)
|
hyp_text = "".join(hyp_text_list).replace("▁", " ")
|
||||||
|
|
||||||
logging.info(f"[utterance id]: {utt_id}")
|
logging.info(f"[utterance id]: {utt_id}")
|
||||||
logging.info(f"[verbatim text]: {verbatim_text}")
|
logging.info(f"[verbatim text]: {verbatim_text}")
|
||||||
logging.info(f"[best alignment]: {hyp_text}")
|
logging.info(f"[best alignment]: {hyp_text}")
|
||||||
|
logging.info(bypass_weight)
|
||||||
|
|
||||||
return loss, info
|
return loss, info
|
||||||
|
|
||||||
@ -770,7 +796,9 @@ def train_one_epoch(
|
|||||||
and params.batch_idx_train % params.average_period == 0
|
and params.batch_idx_train % params.average_period == 0
|
||||||
):
|
):
|
||||||
update_averaged_model(
|
update_averaged_model(
|
||||||
params=params, model_cur=model, model_avg=model_avg,
|
params=params,
|
||||||
|
model_cur=model,
|
||||||
|
model_avg=model_avg,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@ -790,7 +818,9 @@ def train_one_epoch(
|
|||||||
rank=rank,
|
rank=rank,
|
||||||
)
|
)
|
||||||
remove_checkpoints(
|
remove_checkpoints(
|
||||||
out_dir=params.exp_dir, topk=params.keep_last_k, rank=rank,
|
out_dir=params.exp_dir,
|
||||||
|
topk=params.keep_last_k,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
|
@ -909,7 +909,9 @@ class Noam(object):
|
|||||||
|
|
||||||
|
|
||||||
def encoder_padding_mask(
|
def encoder_padding_mask(
|
||||||
max_len: int, subsampling_factor: Optional[int] = 4, supervisions: Optional[Supervisions] = None
|
max_len: int,
|
||||||
|
subsampling_factor: Optional[int] = 4,
|
||||||
|
supervisions: Optional[Supervisions] = None,
|
||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
"""Make mask tensor containing indexes of padded part.
|
"""Make mask tensor containing indexes of padded part.
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, S3PRLSSL, S3PRLSSLConfig, NumpyFilesWriter
|
from lhotse import S3PRLSSL, CutSet, NumpyFilesWriter, S3PRLSSLConfig
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor
|
||||||
@ -71,9 +71,7 @@ def compute_ssl_librispeech():
|
|||||||
dataset_parts,
|
dataset_parts,
|
||||||
)
|
)
|
||||||
|
|
||||||
extractor = S3PRLSSL(
|
extractor = S3PRLSSL(S3PRLSSLConfig(ssl_model="wav2vec2", device="cuda"))
|
||||||
S3PRLSSLConfig(ssl_model="wav2vec2", device="cuda")
|
|
||||||
)
|
|
||||||
|
|
||||||
with get_executor() as ex: # Initialize the executor only once.
|
with get_executor() as ex: # Initialize the executor only once.
|
||||||
for partition, m in manifests.items():
|
for partition, m in manifests.items():
|
||||||
@ -95,9 +93,7 @@ def compute_ssl_librispeech():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
formatter = (
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
|
@ -18,7 +18,9 @@ def get_args():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--otc-token", type=str, help="OTC token to be added to words.txt",
|
"--otc-token",
|
||||||
|
type=str,
|
||||||
|
help="OTC token to be added to words.txt",
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
@ -7,10 +7,11 @@ import random
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
|
||||||
from lhotse import CutSet, load_manifest
|
from lhotse import CutSet, load_manifest
|
||||||
from lhotse.cut.base import Cut
|
from lhotse.cut.base import Cut
|
||||||
|
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -23,23 +24,36 @@ def get_args():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--words-file", type=str, help="words.txt file",
|
"--words-file",
|
||||||
|
type=str,
|
||||||
|
help="words.txt file",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--otc-token", type=str, help="OTC token in words.txt",
|
"--otc-token",
|
||||||
|
type=str,
|
||||||
|
help="OTC token in words.txt",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sub-error-rate", type=float, default=0.0, help="Substitution error rate",
|
"--sub-error-rate",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="Substitution error rate",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ins-error-rate", type=float, default=0.0, help="Insertion error rate",
|
"--ins-error-rate",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="Insertion error rate",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--del-error-rate", type=float, default=0.0, help="Deletion error rate",
|
"--del-error-rate",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="Deletion error rate",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -324,7 +324,9 @@ def lexicon_to_fst(
|
|||||||
disambig_token = token2id["#0"]
|
disambig_token = token2id["#0"]
|
||||||
disambig_word = word2id["#0"]
|
disambig_word = word2id["#0"]
|
||||||
arcs = add_self_loops(
|
arcs = add_self_loops(
|
||||||
arcs, disambig_token=disambig_token, disambig_word=disambig_word,
|
arcs,
|
||||||
|
disambig_token=disambig_token,
|
||||||
|
disambig_word=disambig_word,
|
||||||
)
|
)
|
||||||
|
|
||||||
final_state = next_state
|
final_state = next_state
|
||||||
|
@ -109,7 +109,9 @@ def lexicon_to_fst_no_sil(
|
|||||||
disambig_token = token2id["#0"]
|
disambig_token = token2id["#0"]
|
||||||
disambig_word = word2id["#0"]
|
disambig_word = word2id["#0"]
|
||||||
arcs = add_self_loops(
|
arcs = add_self_loops(
|
||||||
arcs, disambig_token=disambig_token, disambig_word=disambig_word,
|
arcs,
|
||||||
|
disambig_token=disambig_token,
|
||||||
|
disambig_word=disambig_word,
|
||||||
)
|
)
|
||||||
|
|
||||||
final_state = next_state
|
final_state = next_state
|
||||||
@ -126,7 +128,10 @@ def lexicon_to_fst_no_sil(
|
|||||||
|
|
||||||
|
|
||||||
def generate_otc_lexicon(
|
def generate_otc_lexicon(
|
||||||
model_file: str, words: List[str], oov: str, otc_token: str,
|
model_file: str,
|
||||||
|
words: List[str],
|
||||||
|
oov: str,
|
||||||
|
otc_token: str,
|
||||||
) -> Tuple[Lexicon, Dict[str, int]]:
|
) -> Tuple[Lexicon, Dict[str, int]]:
|
||||||
"""Generate a lexicon from a BPE model.
|
"""Generate a lexicon from a BPE model.
|
||||||
|
|
||||||
@ -188,7 +193,10 @@ def get_args():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--otc-token", type=str, default="<star>", help="The OTC token in lexicon.",
|
"--otc-token",
|
||||||
|
type=str,
|
||||||
|
default="<star>",
|
||||||
|
help="The OTC token in lexicon.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -256,7 +264,9 @@ def main():
|
|||||||
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
|
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
|
||||||
|
|
||||||
L = lexicon_to_fst_no_sil(
|
L = lexicon_to_fst_no_sil(
|
||||||
lexicon, token2id=token_sym_table, word2id=word_sym_table,
|
lexicon,
|
||||||
|
token2id=token_sym_table,
|
||||||
|
word2id=word_sym_table,
|
||||||
)
|
)
|
||||||
|
|
||||||
L_disambig = lexicon_to_fst_no_sil(
|
L_disambig = lexicon_to_fst_no_sil(
|
||||||
|
@ -38,7 +38,6 @@ class OtcTrainingGraphCompiler(object):
|
|||||||
initial_self_loop_weight: float = 0.0,
|
initial_self_loop_weight: float = 0.0,
|
||||||
bypass_weight_decay: float = 0.0,
|
bypass_weight_decay: float = 0.0,
|
||||||
self_loop_weight_decay: float = 0.0,
|
self_loop_weight_decay: float = 0.0,
|
||||||
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -93,7 +92,11 @@ class OtcTrainingGraphCompiler(object):
|
|||||||
return max_token_id
|
return max_token_id
|
||||||
|
|
||||||
def make_arc(
|
def make_arc(
|
||||||
self, from_state: int, to_state: int, symbol: Union[str, int], weight: float,
|
self,
|
||||||
|
from_state: int,
|
||||||
|
to_state: int,
|
||||||
|
symbol: Union[str, int],
|
||||||
|
weight: float,
|
||||||
):
|
):
|
||||||
return f"{from_state} {to_state} {symbol} {weight}"
|
return f"{from_state} {to_state} {symbol} {weight}"
|
||||||
|
|
||||||
@ -132,7 +135,7 @@ class OtcTrainingGraphCompiler(object):
|
|||||||
Whether to add bypass arc to training graph for substitution
|
Whether to add bypass arc to training graph for substitution
|
||||||
and insertion errors (wrong or extra words in the transcript).
|
and insertion errors (wrong or extra words in the transcript).
|
||||||
allow_self_loop_arc:
|
allow_self_loop_arc:
|
||||||
Whether to add self-loop arc to training graph for deletion
|
Whether to add self-loop arc to training graph for deletion
|
||||||
errors (missing words in the transcript).
|
errors (missing words in the transcript).
|
||||||
bypass_weight:
|
bypass_weight:
|
||||||
Weight associated with bypass arc.
|
Weight associated with bypass arc.
|
||||||
@ -140,7 +143,7 @@ class OtcTrainingGraphCompiler(object):
|
|||||||
Weight associated with self-loop arc.
|
Weight associated with self-loop arc.
|
||||||
otc_granularity:
|
otc_granularity:
|
||||||
Use OTC token to model word or subword.
|
Use OTC token to model word or subword.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
Return an FsaVec, which is the result of composing a
|
Return an FsaVec, which is the result of composing a
|
||||||
CTC topology with OTC FSAs constructed from the given texts.
|
CTC topology with OTC FSAs constructed from the given texts.
|
||||||
@ -161,7 +164,9 @@ class OtcTrainingGraphCompiler(object):
|
|||||||
fsa_with_self_loop = k2.arc_sort(fsa_with_self_loop)
|
fsa_with_self_loop = k2.arc_sort(fsa_with_self_loop)
|
||||||
|
|
||||||
graph = k2.compose(
|
graph = k2.compose(
|
||||||
self.ctc_topo, fsa_with_self_loop, treat_epsilons_specially=False,
|
self.ctc_topo,
|
||||||
|
fsa_with_self_loop,
|
||||||
|
treat_epsilons_specially=False,
|
||||||
)
|
)
|
||||||
assert graph.requires_grad is False
|
assert graph.requires_grad is False
|
||||||
|
|
||||||
@ -201,7 +206,10 @@ class OtcTrainingGraphCompiler(object):
|
|||||||
|
|
||||||
if allow_self_loop_arc:
|
if allow_self_loop_arc:
|
||||||
self_loop_arc = self.make_arc(
|
self_loop_arc = self.make_arc(
|
||||||
cur_state, cur_state, otc_token_id, self_loop_weight,
|
cur_state,
|
||||||
|
cur_state,
|
||||||
|
otc_token_id,
|
||||||
|
self_loop_weight,
|
||||||
)
|
)
|
||||||
arcs.append(self_loop_arc)
|
arcs.append(self_loop_arc)
|
||||||
|
|
||||||
@ -225,7 +233,10 @@ class OtcTrainingGraphCompiler(object):
|
|||||||
|
|
||||||
if allow_self_loop_arc:
|
if allow_self_loop_arc:
|
||||||
self_loop_arc = self.make_arc(
|
self_loop_arc = self.make_arc(
|
||||||
cur_state, cur_state, otc_token_id, self_loop_weight,
|
cur_state,
|
||||||
|
cur_state,
|
||||||
|
otc_token_id,
|
||||||
|
self_loop_weight,
|
||||||
)
|
)
|
||||||
arcs.append(self_loop_arc)
|
arcs.append(self_loop_arc)
|
||||||
|
|
||||||
|
@ -262,6 +262,7 @@ def get_texts(
|
|||||||
else:
|
else:
|
||||||
return aux_labels.tolist()
|
return aux_labels.tolist()
|
||||||
|
|
||||||
|
|
||||||
def encode_supervisions_otc(
|
def encode_supervisions_otc(
|
||||||
supervisions: dict,
|
supervisions: dict,
|
||||||
subsampling_factor: int,
|
subsampling_factor: int,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user