mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 18:42:19 +00:00
small fix
This commit is contained in:
parent
04ce87e307
commit
bff0822ffa
@ -108,6 +108,27 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
default=512,
|
default=512,
|
||||||
help="Encoder output dimesion.",
|
help="Encoder output dimesion.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder-dim",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="Decoder output dimension.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--joiner-dim",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="Joiner output dimension.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dim-feedforward",
|
||||||
|
type=int,
|
||||||
|
default=2048,
|
||||||
|
help="Dimension of feed forward.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--rnn-hidden-size",
|
"--rnn-hidden-size",
|
||||||
@ -402,11 +423,6 @@ def get_params() -> AttributeDict:
|
|||||||
# parameters for conformer
|
# parameters for conformer
|
||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
"subsampling_factor": 4,
|
"subsampling_factor": 4,
|
||||||
"dim_feedforward": 2048,
|
|
||||||
# parameters for decoder
|
|
||||||
"decoder_dim": 512,
|
|
||||||
# parameters for joiner
|
|
||||||
"joiner_dim": 512,
|
|
||||||
# parameters for Noam
|
# parameters for Noam
|
||||||
"model_warm_step": 3000, # arg given to model, not for lrate
|
"model_warm_step": 3000, # arg given to model, not for lrate
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
@ -426,6 +442,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
dim_feedforward=params.dim_feedforward,
|
dim_feedforward=params.dim_feedforward,
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
num_encoder_layers=params.num_encoder_layers,
|
||||||
aux_layer_period=params.aux_layer_period,
|
aux_layer_period=params.aux_layer_period,
|
||||||
|
is_pnnx=params.is_pnnx,
|
||||||
)
|
)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
@ -619,6 +636,7 @@ def compute_loss(
|
|||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
|
#import pdb; pdb.set_trace()
|
||||||
y = graph_compiler.texts_to_ids_with_bpe(texts)
|
y = graph_compiler.texts_to_ids_with_bpe(texts)
|
||||||
if type(y) == list:
|
if type(y) == list:
|
||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
@ -910,8 +928,6 @@ def run(rank, world_size, args):
|
|||||||
"""
|
"""
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
if params.full_libri is False:
|
|
||||||
params.valid_interval = 800
|
|
||||||
|
|
||||||
fix_random_seed(params.seed)
|
fix_random_seed(params.seed)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
@ -930,6 +946,12 @@ def run(rank, world_size, args):
|
|||||||
device = torch.device("cuda", rank)
|
device = torch.device("cuda", rank)
|
||||||
logging.info(f"Device: {device}")
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
|
bpe_model = params.lang_dir + "/bpe.model"
|
||||||
|
import sentencepiece as spm
|
||||||
|
|
||||||
|
sp = spm.SentencePieceProcessor()
|
||||||
|
sp.load(bpe_model)
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||||
lexicon=lexicon,
|
lexicon=lexicon,
|
||||||
@ -1001,33 +1023,7 @@ def run(rank, world_size, args):
|
|||||||
# You should use ../local/display_manifest_statistics.py to get
|
# You should use ../local/display_manifest_statistics.py to get
|
||||||
# an utterance duration distribution for your dataset to select
|
# an utterance duration distribution for your dataset to select
|
||||||
# the threshold
|
# the threshold
|
||||||
if c.duration < 1.0 or c.duration > 20.0:
|
return 1.0 <= c.duration <= 20.0
|
||||||
logging.warning(
|
|
||||||
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
# In pruned RNN-T, we require that T >= S
|
|
||||||
# where T is the number of feature frames after subsampling
|
|
||||||
# and S is the number of tokens in the utterance
|
|
||||||
|
|
||||||
# In ./lstm.py, the conv module uses the following expression
|
|
||||||
# for subsampling
|
|
||||||
T = ((c.num_frames - 3) // 2 - 1) // 2
|
|
||||||
tokens = sp.encode(c.supervisions[0].text, out_type=str)
|
|
||||||
|
|
||||||
if T < len(tokens):
|
|
||||||
logging.warning(
|
|
||||||
f"Exclude cut with ID {c.id} from training. "
|
|
||||||
f"Number of frames (before subsampling): {c.num_frames}. "
|
|
||||||
f"Number of frames (after subsampling): {T}. "
|
|
||||||
f"Text: {c.supervisions[0].text}. "
|
|
||||||
f"Tokens: {tokens}. "
|
|
||||||
f"Number of tokens: {len(tokens)}"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def text_normalize_for_cut(c: Cut):
|
def text_normalize_for_cut(c: Cut):
|
||||||
# Text normalize for each sample
|
# Text normalize for each sample
|
||||||
@ -1056,15 +1052,15 @@ def run(rank, world_size, args):
|
|||||||
valid_cuts = valid_cuts.map(text_normalize_for_cut)
|
valid_cuts = valid_cuts.map(text_normalize_for_cut)
|
||||||
valid_dl = tal_csasr.valid_dataloaders(valid_cuts)
|
valid_dl = tal_csasr.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
if not params.print_diagnostics:
|
# if not params.print_diagnostics:
|
||||||
scan_pessimistic_batches_for_oom(
|
# scan_pessimistic_batches_for_oom(
|
||||||
model=model,
|
# model=model,
|
||||||
train_dl=train_dl,
|
# train_dl=train_dl,
|
||||||
optimizer=optimizer,
|
# optimizer=optimizer,
|
||||||
graph_compiler=graph_compiler,
|
# graph_compiler=graph_compiler,
|
||||||
params=params,
|
# params=params,
|
||||||
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
# warmup=0.0 if params.start_epoch == 1 else 1.0,
|
||||||
)
|
# )
|
||||||
|
|
||||||
scaler = GradScaler(enabled=params.use_fp16)
|
scaler = GradScaler(enabled=params.use_fp16)
|
||||||
if checkpoints and "grad_scaler" in checkpoints:
|
if checkpoints and "grad_scaler" in checkpoints:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user