mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 02:22:17 +00:00
fix code style
This commit is contained in:
parent
dae3a310f4
commit
a57c54124a
@ -37,45 +37,31 @@ def get_args():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tokens",
|
"--tokens", type=str, help="Path to tokens.txt",
|
||||||
type=str,
|
|
||||||
help="Path to tokens.txt",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--encoder-param-filename",
|
"--encoder-param-filename", type=str, help="Path to encoder.ncnn.param",
|
||||||
type=str,
|
|
||||||
help="Path to encoder.ncnn.param",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--encoder-bin-filename",
|
"--encoder-bin-filename", type=str, help="Path to encoder.ncnn.bin",
|
||||||
type=str,
|
|
||||||
help="Path to encoder.ncnn.bin",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decoder-param-filename",
|
"--decoder-param-filename", type=str, help="Path to decoder.ncnn.param",
|
||||||
type=str,
|
|
||||||
help="Path to decoder.ncnn.param",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decoder-bin-filename",
|
"--decoder-bin-filename", type=str, help="Path to decoder.ncnn.bin",
|
||||||
type=str,
|
|
||||||
help="Path to decoder.ncnn.bin",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--joiner-param-filename",
|
"--joiner-param-filename", type=str, help="Path to joiner.ncnn.param",
|
||||||
type=str,
|
|
||||||
help="Path to joiner.ncnn.param",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--joiner-bin-filename",
|
"--joiner-bin-filename", type=str, help="Path to joiner.ncnn.bin",
|
||||||
type=str,
|
|
||||||
help="Path to joiner.ncnn.bin",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -86,23 +72,15 @@ def get_args():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--encoder-dim",
|
"--encoder-dim", type=int, default=512, help="Encoder output dimesion.",
|
||||||
type=int,
|
|
||||||
default=512,
|
|
||||||
help="Encoder output dimesion.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--rnn-hidden-size",
|
"--rnn-hidden-size", type=int, default=2048, help="Dimension of feed forward.",
|
||||||
type=int,
|
|
||||||
default=2048,
|
|
||||||
help="Dimension of feed forward.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"sound_filename",
|
"sound_filename", type=str, help="Path to foo.wav",
|
||||||
type=str,
|
|
||||||
help="Path to foo.wav",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
@ -286,8 +264,7 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"Reading sound files: {sound_file}")
|
logging.info(f"Reading sound files: {sound_file}")
|
||||||
wave_samples = read_sound_files(
|
wave_samples = read_sound_files(
|
||||||
filenames=[sound_file],
|
filenames=[sound_file], expected_sample_rate=sample_rate,
|
||||||
expected_sample_rate=sample_rate,
|
|
||||||
)[0]
|
)[0]
|
||||||
logging.info(wave_samples.shape)
|
logging.info(wave_samples.shape)
|
||||||
|
|
||||||
@ -298,11 +275,7 @@ def main():
|
|||||||
|
|
||||||
states = (
|
states = (
|
||||||
torch.zeros(num_encoder_layers, batch_size, d_model),
|
torch.zeros(num_encoder_layers, batch_size, d_model),
|
||||||
torch.zeros(
|
torch.zeros(num_encoder_layers, batch_size, rnn_hidden_size,),
|
||||||
num_encoder_layers,
|
|
||||||
batch_size,
|
|
||||||
rnn_hidden_size,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
hyp = None
|
hyp = None
|
||||||
@ -321,8 +294,7 @@ def main():
|
|||||||
start += chunk
|
start += chunk
|
||||||
|
|
||||||
online_fbank.accept_waveform(
|
online_fbank.accept_waveform(
|
||||||
sampling_rate=sample_rate,
|
sampling_rate=sample_rate, waveform=samples,
|
||||||
waveform=samples,
|
|
||||||
)
|
)
|
||||||
while online_fbank.num_frames_ready - num_processed_frames >= segment:
|
while online_fbank.num_frames_ready - num_processed_frames >= segment:
|
||||||
frames = []
|
frames = []
|
||||||
|
@ -215,10 +215,7 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sampling-rate",
|
"--sampling-rate", type=float, default=16000, help="Sample rate of the audio",
|
||||||
type=float,
|
|
||||||
default=16000,
|
|
||||||
help="Sample rate of the audio",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -234,9 +231,7 @@ def get_parser():
|
|||||||
|
|
||||||
|
|
||||||
def greedy_search(
|
def greedy_search(
|
||||||
model: nn.Module,
|
model: nn.Module, encoder_out: torch.Tensor, streams: List[Stream],
|
||||||
encoder_out: torch.Tensor,
|
|
||||||
streams: List[Stream],
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||||
|
|
||||||
@ -293,18 +288,12 @@ def greedy_search(
|
|||||||
device=device,
|
device=device,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
)
|
)
|
||||||
decoder_out = model.decoder(
|
decoder_out = model.decoder(decoder_input, need_pad=False,)
|
||||||
decoder_input,
|
|
||||||
need_pad=False,
|
|
||||||
)
|
|
||||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||||
|
|
||||||
|
|
||||||
def modified_beam_search(
|
def modified_beam_search(
|
||||||
model: nn.Module,
|
model: nn.Module, encoder_out: torch.Tensor, streams: List[Stream], beam: int = 4,
|
||||||
encoder_out: torch.Tensor,
|
|
||||||
streams: List[Stream],
|
|
||||||
beam: int = 4,
|
|
||||||
):
|
):
|
||||||
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||||
|
|
||||||
@ -358,9 +347,7 @@ def modified_beam_search(
|
|||||||
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
|
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
|
||||||
# as index, so we use `to(torch.int64)` below.
|
# as index, so we use `to(torch.int64)` below.
|
||||||
current_encoder_out = torch.index_select(
|
current_encoder_out = torch.index_select(
|
||||||
current_encoder_out,
|
current_encoder_out, dim=0, index=hyps_shape.row_ids(1).to(torch.int64),
|
||||||
dim=0,
|
|
||||||
index=hyps_shape.row_ids(1).to(torch.int64),
|
|
||||||
) # (num_hyps, encoder_out_dim)
|
) # (num_hyps, encoder_out_dim)
|
||||||
|
|
||||||
logits = model.joiner(current_encoder_out, decoder_out, project_input=False)
|
logits = model.joiner(current_encoder_out, decoder_out, project_input=False)
|
||||||
@ -547,26 +534,19 @@ def decode_one_chunk(
|
|||||||
pad_length = tail_length - features.size(1)
|
pad_length = tail_length - features.size(1)
|
||||||
feature_lens += pad_length
|
feature_lens += pad_length
|
||||||
features = torch.nn.functional.pad(
|
features = torch.nn.functional.pad(
|
||||||
features,
|
features, (0, 0, 0, pad_length), mode="constant", value=LOG_EPSILON,
|
||||||
(0, 0, 0, pad_length),
|
|
||||||
mode="constant",
|
|
||||||
value=LOG_EPSILON,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Stack states of all streams
|
# Stack states of all streams
|
||||||
states = stack_states(state_list)
|
states = stack_states(state_list)
|
||||||
|
|
||||||
encoder_out, encoder_out_lens, states = model.encoder(
|
encoder_out, encoder_out_lens, states = model.encoder(
|
||||||
x=features,
|
x=features, x_lens=feature_lens, states=states,
|
||||||
x_lens=feature_lens,
|
|
||||||
states=states,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
greedy_search(
|
greedy_search(
|
||||||
model=model,
|
model=model, streams=streams, encoder_out=encoder_out,
|
||||||
streams=streams,
|
|
||||||
encoder_out=encoder_out,
|
|
||||||
)
|
)
|
||||||
elif params.decoding_method == "modified_beam_search":
|
elif params.decoding_method == "modified_beam_search":
|
||||||
modified_beam_search(
|
modified_beam_search(
|
||||||
@ -725,10 +705,7 @@ def decode_dataset(
|
|||||||
|
|
||||||
while len(streams) > 0:
|
while len(streams) > 0:
|
||||||
finished_streams = decode_one_chunk(
|
finished_streams = decode_one_chunk(
|
||||||
model=model,
|
model=model, streams=streams, params=params, decoding_graph=decoding_graph,
|
||||||
streams=streams,
|
|
||||||
params=params,
|
|
||||||
decoding_graph=decoding_graph,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for i in sorted(finished_streams, reverse=True):
|
for i in sorted(finished_streams, reverse=True):
|
||||||
@ -848,10 +825,7 @@ def main():
|
|||||||
sp.load(bpe_model)
|
sp.load(bpe_model)
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
graph_compiler = CharCtcTrainingGraphCompiler(lexicon=lexicon, device=device,)
|
||||||
lexicon=lexicon,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
params.blank_id = lexicon.token_table["<blk>"]
|
params.blank_id = lexicon.token_table["<blk>"]
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
@ -979,9 +953,7 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
params=params,
|
params=params, test_set_name=test_set, results_dict=results_dict,
|
||||||
test_set_name=test_set,
|
|
||||||
results_dict=results_dict,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Done!")
|
logging.info("Done!")
|
||||||
|
@ -103,38 +103,23 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--encoder-dim",
|
"--encoder-dim", type=int, default=512, help="Encoder output dimesion.",
|
||||||
type=int,
|
|
||||||
default=512,
|
|
||||||
help="Encoder output dimesion.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decoder-dim",
|
"--decoder-dim", type=int, default=512, help="Decoder output dimension.",
|
||||||
type=int,
|
|
||||||
default=512,
|
|
||||||
help="Decoder output dimension.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--joiner-dim",
|
"--joiner-dim", type=int, default=512, help="Joiner output dimension.",
|
||||||
type=int,
|
|
||||||
default=512,
|
|
||||||
help="Joiner output dimension.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dim-feedforward",
|
"--dim-feedforward", type=int, default=2048, help="Dimension of feed forward.",
|
||||||
type=int,
|
|
||||||
default=2048,
|
|
||||||
help="Dimension of feed forward.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--rnn-hidden-size",
|
"--rnn-hidden-size", type=int, default=1024, help="Hidden dim for LSTM layers.",
|
||||||
type=int,
|
|
||||||
default=1024,
|
|
||||||
help="Hidden dim for LSTM layers.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -171,10 +156,7 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--world-size",
|
"--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="Number of GPUs for DDP training.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -192,10 +174,7 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-epochs",
|
"--num-epochs", type=int, default=40, help="Number of epochs to train.",
|
||||||
type=int,
|
|
||||||
default=40,
|
|
||||||
help="Number of epochs to train.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -670,7 +649,7 @@ def compute_loss(
|
|||||||
f"simple_loss: {simple_loss}\n"
|
f"simple_loss: {simple_loss}\n"
|
||||||
f"pruned_loss: {pruned_loss}"
|
f"pruned_loss: {pruned_loss}"
|
||||||
)
|
)
|
||||||
display_and_save_batch(batch, params=params, sp=sp)
|
display_and_save_batch(batch, params=params)
|
||||||
simple_loss = simple_loss[simple_loss_is_finite]
|
simple_loss = simple_loss[simple_loss_is_finite]
|
||||||
pruned_loss = pruned_loss[pruned_loss_is_finite]
|
pruned_loss = pruned_loss[pruned_loss_is_finite]
|
||||||
|
|
||||||
@ -834,7 +813,7 @@ def train_one_epoch(
|
|||||||
scaler.update()
|
scaler.update()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
except: # noqa
|
except: # noqa
|
||||||
display_and_save_batch(batch, params=params, sp=sp)
|
display_and_save_batch(batch, params=params)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if params.print_diagnostics and batch_idx == 30:
|
if params.print_diagnostics and batch_idx == 30:
|
||||||
@ -846,9 +825,7 @@ def train_one_epoch(
|
|||||||
and params.batch_idx_train % params.average_period == 0
|
and params.batch_idx_train % params.average_period == 0
|
||||||
):
|
):
|
||||||
update_averaged_model(
|
update_averaged_model(
|
||||||
params=params,
|
params=params, model_cur=model, model_avg=model_avg,
|
||||||
model_cur=model,
|
|
||||||
model_avg=model_avg,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@ -870,9 +847,7 @@ def train_one_epoch(
|
|||||||
)
|
)
|
||||||
del params.cur_batch_idx
|
del params.cur_batch_idx
|
||||||
remove_checkpoints(
|
remove_checkpoints(
|
||||||
out_dir=params.exp_dir,
|
out_dir=params.exp_dir, topk=params.keep_last_k, rank=rank,
|
||||||
topk=params.keep_last_k,
|
|
||||||
rank=rank,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0 and not params.print_diagnostics:
|
if batch_idx % params.log_interval == 0 and not params.print_diagnostics:
|
||||||
@ -960,10 +935,7 @@ def run(rank, world_size, args):
|
|||||||
sp.load(bpe_model)
|
sp.load(bpe_model)
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
graph_compiler = CharCtcTrainingGraphCompiler(lexicon=lexicon, device=device,)
|
||||||
lexicon=lexicon,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
params.blank_id = lexicon.token_table["<blk>"]
|
params.blank_id = lexicon.token_table["<blk>"]
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
@ -1014,7 +986,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
if params.print_diagnostics:
|
if params.print_diagnostics:
|
||||||
opts = diagnostics.TensorDiagnosticOptions(
|
opts = diagnostics.TensorDiagnosticOptions(
|
||||||
2**22
|
2 ** 22
|
||||||
) # allow 4 megabytes per sub-module
|
) # allow 4 megabytes per sub-module
|
||||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user