fix code style

This commit is contained in:
marcoyang 2023-02-13 17:55:34 +08:00
parent dae3a310f4
commit a57c54124a
3 changed files with 37 additions and 121 deletions

View File

@ -37,45 +37,31 @@ def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--tokens", "--tokens", type=str, help="Path to tokens.txt",
type=str,
help="Path to tokens.txt",
) )
parser.add_argument( parser.add_argument(
"--encoder-param-filename", "--encoder-param-filename", type=str, help="Path to encoder.ncnn.param",
type=str,
help="Path to encoder.ncnn.param",
) )
parser.add_argument( parser.add_argument(
"--encoder-bin-filename", "--encoder-bin-filename", type=str, help="Path to encoder.ncnn.bin",
type=str,
help="Path to encoder.ncnn.bin",
) )
parser.add_argument( parser.add_argument(
"--decoder-param-filename", "--decoder-param-filename", type=str, help="Path to decoder.ncnn.param",
type=str,
help="Path to decoder.ncnn.param",
) )
parser.add_argument( parser.add_argument(
"--decoder-bin-filename", "--decoder-bin-filename", type=str, help="Path to decoder.ncnn.bin",
type=str,
help="Path to decoder.ncnn.bin",
) )
parser.add_argument( parser.add_argument(
"--joiner-param-filename", "--joiner-param-filename", type=str, help="Path to joiner.ncnn.param",
type=str,
help="Path to joiner.ncnn.param",
) )
parser.add_argument( parser.add_argument(
"--joiner-bin-filename", "--joiner-bin-filename", type=str, help="Path to joiner.ncnn.bin",
type=str,
help="Path to joiner.ncnn.bin",
) )
parser.add_argument( parser.add_argument(
@ -86,23 +72,15 @@ def get_args():
) )
parser.add_argument( parser.add_argument(
"--encoder-dim", "--encoder-dim", type=int, default=512, help="Encoder output dimesion.",
type=int,
default=512,
help="Encoder output dimesion.",
) )
parser.add_argument( parser.add_argument(
"--rnn-hidden-size", "--rnn-hidden-size", type=int, default=2048, help="Dimension of feed forward.",
type=int,
default=2048,
help="Dimension of feed forward.",
) )
parser.add_argument( parser.add_argument(
"sound_filename", "sound_filename", type=str, help="Path to foo.wav",
type=str,
help="Path to foo.wav",
) )
return parser.parse_args() return parser.parse_args()
@ -286,8 +264,7 @@ def main():
logging.info(f"Reading sound files: {sound_file}") logging.info(f"Reading sound files: {sound_file}")
wave_samples = read_sound_files( wave_samples = read_sound_files(
filenames=[sound_file], filenames=[sound_file], expected_sample_rate=sample_rate,
expected_sample_rate=sample_rate,
)[0] )[0]
logging.info(wave_samples.shape) logging.info(wave_samples.shape)
@ -298,11 +275,7 @@ def main():
states = ( states = (
torch.zeros(num_encoder_layers, batch_size, d_model), torch.zeros(num_encoder_layers, batch_size, d_model),
torch.zeros( torch.zeros(num_encoder_layers, batch_size, rnn_hidden_size,),
num_encoder_layers,
batch_size,
rnn_hidden_size,
),
) )
hyp = None hyp = None
@ -321,8 +294,7 @@ def main():
start += chunk start += chunk
online_fbank.accept_waveform( online_fbank.accept_waveform(
sampling_rate=sample_rate, sampling_rate=sample_rate, waveform=samples,
waveform=samples,
) )
while online_fbank.num_frames_ready - num_processed_frames >= segment: while online_fbank.num_frames_ready - num_processed_frames >= segment:
frames = [] frames = []

View File

@ -215,10 +215,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--sampling-rate", "--sampling-rate", type=float, default=16000, help="Sample rate of the audio",
type=float,
default=16000,
help="Sample rate of the audio",
) )
parser.add_argument( parser.add_argument(
@ -234,9 +231,7 @@ def get_parser():
def greedy_search( def greedy_search(
model: nn.Module, model: nn.Module, encoder_out: torch.Tensor, streams: List[Stream],
encoder_out: torch.Tensor,
streams: List[Stream],
) -> None: ) -> None:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
@ -293,18 +288,12 @@ def greedy_search(
device=device, device=device,
dtype=torch.int64, dtype=torch.int64,
) )
decoder_out = model.decoder( decoder_out = model.decoder(decoder_input, need_pad=False,)
decoder_input,
need_pad=False,
)
decoder_out = model.joiner.decoder_proj(decoder_out) decoder_out = model.joiner.decoder_proj(decoder_out)
def modified_beam_search( def modified_beam_search(
model: nn.Module, model: nn.Module, encoder_out: torch.Tensor, streams: List[Stream], beam: int = 4,
encoder_out: torch.Tensor,
streams: List[Stream],
beam: int = 4,
): ):
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@ -358,9 +347,7 @@ def modified_beam_search(
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below. # as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select( current_encoder_out = torch.index_select(
current_encoder_out, current_encoder_out, dim=0, index=hyps_shape.row_ids(1).to(torch.int64),
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, encoder_out_dim) ) # (num_hyps, encoder_out_dim)
logits = model.joiner(current_encoder_out, decoder_out, project_input=False) logits = model.joiner(current_encoder_out, decoder_out, project_input=False)
@ -547,26 +534,19 @@ def decode_one_chunk(
pad_length = tail_length - features.size(1) pad_length = tail_length - features.size(1)
feature_lens += pad_length feature_lens += pad_length
features = torch.nn.functional.pad( features = torch.nn.functional.pad(
features, features, (0, 0, 0, pad_length), mode="constant", value=LOG_EPSILON,
(0, 0, 0, pad_length),
mode="constant",
value=LOG_EPSILON,
) )
# Stack states of all streams # Stack states of all streams
states = stack_states(state_list) states = stack_states(state_list)
encoder_out, encoder_out_lens, states = model.encoder( encoder_out, encoder_out_lens, states = model.encoder(
x=features, x=features, x_lens=feature_lens, states=states,
x_lens=feature_lens,
states=states,
) )
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
greedy_search( greedy_search(
model=model, model=model, streams=streams, encoder_out=encoder_out,
streams=streams,
encoder_out=encoder_out,
) )
elif params.decoding_method == "modified_beam_search": elif params.decoding_method == "modified_beam_search":
modified_beam_search( modified_beam_search(
@ -725,10 +705,7 @@ def decode_dataset(
while len(streams) > 0: while len(streams) > 0:
finished_streams = decode_one_chunk( finished_streams = decode_one_chunk(
model=model, model=model, streams=streams, params=params, decoding_graph=decoding_graph,
streams=streams,
params=params,
decoding_graph=decoding_graph,
) )
for i in sorted(finished_streams, reverse=True): for i in sorted(finished_streams, reverse=True):
@ -848,10 +825,7 @@ def main():
sp.load(bpe_model) sp.load(bpe_model)
lexicon = Lexicon(params.lang_dir) lexicon = Lexicon(params.lang_dir)
graph_compiler = CharCtcTrainingGraphCompiler( graph_compiler = CharCtcTrainingGraphCompiler(lexicon=lexicon, device=device,)
lexicon=lexicon,
device=device,
)
params.blank_id = lexicon.token_table["<blk>"] params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1 params.vocab_size = max(lexicon.tokens) + 1
@ -979,9 +953,7 @@ def main():
) )
save_results( save_results(
params=params, params=params, test_set_name=test_set, results_dict=results_dict,
test_set_name=test_set,
results_dict=results_dict,
) )
logging.info("Done!") logging.info("Done!")

View File

@ -103,38 +103,23 @@ def add_model_arguments(parser: argparse.ArgumentParser):
) )
parser.add_argument( parser.add_argument(
"--encoder-dim", "--encoder-dim", type=int, default=512, help="Encoder output dimesion.",
type=int,
default=512,
help="Encoder output dimesion.",
) )
parser.add_argument( parser.add_argument(
"--decoder-dim", "--decoder-dim", type=int, default=512, help="Decoder output dimension.",
type=int,
default=512,
help="Decoder output dimension.",
) )
parser.add_argument( parser.add_argument(
"--joiner-dim", "--joiner-dim", type=int, default=512, help="Joiner output dimension.",
type=int,
default=512,
help="Joiner output dimension.",
) )
parser.add_argument( parser.add_argument(
"--dim-feedforward", "--dim-feedforward", type=int, default=2048, help="Dimension of feed forward.",
type=int,
default=2048,
help="Dimension of feed forward.",
) )
parser.add_argument( parser.add_argument(
"--rnn-hidden-size", "--rnn-hidden-size", type=int, default=1024, help="Hidden dim for LSTM layers.",
type=int,
default=1024,
help="Hidden dim for LSTM layers.",
) )
parser.add_argument( parser.add_argument(
@ -171,10 +156,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--world-size", "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
type=int,
default=1,
help="Number of GPUs for DDP training.",
) )
parser.add_argument( parser.add_argument(
@ -192,10 +174,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--num-epochs", "--num-epochs", type=int, default=40, help="Number of epochs to train.",
type=int,
default=40,
help="Number of epochs to train.",
) )
parser.add_argument( parser.add_argument(
@ -670,7 +649,7 @@ def compute_loss(
f"simple_loss: {simple_loss}\n" f"simple_loss: {simple_loss}\n"
f"pruned_loss: {pruned_loss}" f"pruned_loss: {pruned_loss}"
) )
display_and_save_batch(batch, params=params, sp=sp) display_and_save_batch(batch, params=params)
simple_loss = simple_loss[simple_loss_is_finite] simple_loss = simple_loss[simple_loss_is_finite]
pruned_loss = pruned_loss[pruned_loss_is_finite] pruned_loss = pruned_loss[pruned_loss_is_finite]
@ -834,7 +813,7 @@ def train_one_epoch(
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
except: # noqa except: # noqa
display_and_save_batch(batch, params=params, sp=sp) display_and_save_batch(batch, params=params)
raise raise
if params.print_diagnostics and batch_idx == 30: if params.print_diagnostics and batch_idx == 30:
@ -846,9 +825,7 @@ def train_one_epoch(
and params.batch_idx_train % params.average_period == 0 and params.batch_idx_train % params.average_period == 0
): ):
update_averaged_model( update_averaged_model(
params=params, params=params, model_cur=model, model_avg=model_avg,
model_cur=model,
model_avg=model_avg,
) )
if ( if (
@ -870,9 +847,7 @@ def train_one_epoch(
) )
del params.cur_batch_idx del params.cur_batch_idx
remove_checkpoints( remove_checkpoints(
out_dir=params.exp_dir, out_dir=params.exp_dir, topk=params.keep_last_k, rank=rank,
topk=params.keep_last_k,
rank=rank,
) )
if batch_idx % params.log_interval == 0 and not params.print_diagnostics: if batch_idx % params.log_interval == 0 and not params.print_diagnostics:
@ -960,10 +935,7 @@ def run(rank, world_size, args):
sp.load(bpe_model) sp.load(bpe_model)
lexicon = Lexicon(params.lang_dir) lexicon = Lexicon(params.lang_dir)
graph_compiler = CharCtcTrainingGraphCompiler( graph_compiler = CharCtcTrainingGraphCompiler(lexicon=lexicon, device=device,)
lexicon=lexicon,
device=device,
)
params.blank_id = lexicon.token_table["<blk>"] params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1 params.vocab_size = max(lexicon.tokens) + 1
@ -1014,7 +986,7 @@ def run(rank, world_size, args):
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions( opts = diagnostics.TensorDiagnosticOptions(
2**22 2 ** 22
) # allow 4 megabytes per sub-module ) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts) diagnostic = diagnostics.attach_diagnostics(model, opts)