mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
fix codestyle
This commit is contained in:
parent
7d217e15ab
commit
1aa2a930b4
@ -302,7 +302,9 @@ def decode_one_batch(
|
||||
en_hyps.append(en_text)
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens,
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for i in range(encoder_out.size(0)):
|
||||
hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||
@ -358,7 +360,9 @@ def decode_one_batch(
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size,
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -722,13 +726,19 @@ def main():
|
||||
sp=sp,
|
||||
)
|
||||
save_results(
|
||||
params=params, test_set_name=test_set, results_dict=results_dict,
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
save_results(
|
||||
params=params, test_set_name=test_set, results_dict=zh_results_dict,
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=zh_results_dict,
|
||||
)
|
||||
save_results(
|
||||
params=params, test_set_name=test_set, results_dict=en_results_dict,
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=en_results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
@ -107,7 +107,10 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir", type=str, default="data/lang_char", help="Path to the lang",
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_char",
|
||||
help="Path to the lang",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -134,7 +137,8 @@ def get_parser():
|
||||
|
||||
|
||||
def export_encoder_model_jit_trace(
|
||||
encoder_model: torch.nn.Module, encoder_filename: str,
|
||||
encoder_model: torch.nn.Module,
|
||||
encoder_filename: str,
|
||||
) -> None:
|
||||
"""Export the given encoder model with torch.jit.trace()
|
||||
|
||||
@ -156,7 +160,8 @@ def export_encoder_model_jit_trace(
|
||||
|
||||
|
||||
def export_decoder_model_jit_trace(
|
||||
decoder_model: torch.nn.Module, decoder_filename: str,
|
||||
decoder_model: torch.nn.Module,
|
||||
decoder_filename: str,
|
||||
) -> None:
|
||||
"""Export the given decoder model with torch.jit.trace()
|
||||
|
||||
@ -177,7 +182,8 @@ def export_decoder_model_jit_trace(
|
||||
|
||||
|
||||
def export_joiner_model_jit_trace(
|
||||
joiner_model: torch.nn.Module, joiner_filename: str,
|
||||
joiner_model: torch.nn.Module,
|
||||
joiner_filename: str,
|
||||
) -> None:
|
||||
"""Export the given joiner model with torch.jit.trace()
|
||||
|
||||
|
@ -37,31 +37,45 @@ def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens", type=str, help="Path to tokens.txt",
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="Path to tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-param-filename", type=str, help="Path to encoder.ncnn.param",
|
||||
"--encoder-param-filename",
|
||||
type=str,
|
||||
help="Path to encoder.ncnn.param",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-bin-filename", type=str, help="Path to encoder.ncnn.bin",
|
||||
"--encoder-bin-filename",
|
||||
type=str,
|
||||
help="Path to encoder.ncnn.bin",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-param-filename", type=str, help="Path to decoder.ncnn.param",
|
||||
"--decoder-param-filename",
|
||||
type=str,
|
||||
help="Path to decoder.ncnn.param",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-bin-filename", type=str, help="Path to decoder.ncnn.bin",
|
||||
"--decoder-bin-filename",
|
||||
type=str,
|
||||
help="Path to decoder.ncnn.bin",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-param-filename", type=str, help="Path to joiner.ncnn.param",
|
||||
"--joiner-param-filename",
|
||||
type=str,
|
||||
help="Path to joiner.ncnn.param",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-bin-filename", type=str, help="Path to joiner.ncnn.bin",
|
||||
"--joiner-bin-filename",
|
||||
type=str,
|
||||
help="Path to joiner.ncnn.bin",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -72,15 +86,23 @@ def get_args():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-dim", type=int, default=512, help="Encoder output dimesion.",
|
||||
"--encoder-dim",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Encoder output dimesion.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-hidden-size", type=int, default=2048, help="Dimension of feed forward.",
|
||||
"--rnn-hidden-size",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Dimension of feed forward.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_filename", type=str, help="Path to foo.wav",
|
||||
"sound_filename",
|
||||
type=str,
|
||||
help="Path to foo.wav",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
@ -264,7 +286,8 @@ def main():
|
||||
|
||||
logging.info(f"Reading sound files: {sound_file}")
|
||||
wave_samples = read_sound_files(
|
||||
filenames=[sound_file], expected_sample_rate=sample_rate,
|
||||
filenames=[sound_file],
|
||||
expected_sample_rate=sample_rate,
|
||||
)[0]
|
||||
logging.info(wave_samples.shape)
|
||||
|
||||
@ -275,7 +298,11 @@ def main():
|
||||
|
||||
states = (
|
||||
torch.zeros(num_encoder_layers, batch_size, d_model),
|
||||
torch.zeros(num_encoder_layers, batch_size, rnn_hidden_size,),
|
||||
torch.zeros(
|
||||
num_encoder_layers,
|
||||
batch_size,
|
||||
rnn_hidden_size,
|
||||
),
|
||||
)
|
||||
|
||||
hyp = None
|
||||
@ -294,7 +321,8 @@ def main():
|
||||
start += chunk
|
||||
|
||||
online_fbank.accept_waveform(
|
||||
sampling_rate=sample_rate, waveform=samples,
|
||||
sampling_rate=sample_rate,
|
||||
waveform=samples,
|
||||
)
|
||||
while online_fbank.num_frames_ready - num_processed_frames >= segment:
|
||||
frames = []
|
||||
|
@ -215,7 +215,10 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sampling-rate", type=float, default=16000, help="Sample rate of the audio",
|
||||
"--sampling-rate",
|
||||
type=float,
|
||||
default=16000,
|
||||
help="Sample rate of the audio",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -231,7 +234,9 @@ def get_parser():
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: nn.Module, encoder_out: torch.Tensor, streams: List[Stream],
|
||||
model: nn.Module,
|
||||
encoder_out: torch.Tensor,
|
||||
streams: List[Stream],
|
||||
) -> None:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
|
||||
@ -288,12 +293,18 @@ def greedy_search(
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False,)
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=False,
|
||||
)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
|
||||
|
||||
def modified_beam_search(
|
||||
model: nn.Module, encoder_out: torch.Tensor, streams: List[Stream], beam: int = 4,
|
||||
model: nn.Module,
|
||||
encoder_out: torch.Tensor,
|
||||
streams: List[Stream],
|
||||
beam: int = 4,
|
||||
):
|
||||
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||
|
||||
@ -347,7 +358,9 @@ def modified_beam_search(
|
||||
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
|
||||
# as index, so we use `to(torch.int64)` below.
|
||||
current_encoder_out = torch.index_select(
|
||||
current_encoder_out, dim=0, index=hyps_shape.row_ids(1).to(torch.int64),
|
||||
current_encoder_out,
|
||||
dim=0,
|
||||
index=hyps_shape.row_ids(1).to(torch.int64),
|
||||
) # (num_hyps, encoder_out_dim)
|
||||
|
||||
logits = model.joiner(current_encoder_out, decoder_out, project_input=False)
|
||||
@ -534,19 +547,26 @@ def decode_one_chunk(
|
||||
pad_length = tail_length - features.size(1)
|
||||
feature_lens += pad_length
|
||||
features = torch.nn.functional.pad(
|
||||
features, (0, 0, 0, pad_length), mode="constant", value=LOG_EPSILON,
|
||||
features,
|
||||
(0, 0, 0, pad_length),
|
||||
mode="constant",
|
||||
value=LOG_EPSILON,
|
||||
)
|
||||
|
||||
# Stack states of all streams
|
||||
states = stack_states(state_list)
|
||||
|
||||
encoder_out, encoder_out_lens, states = model.encoder(
|
||||
x=features, x_lens=feature_lens, states=states,
|
||||
x=features,
|
||||
x_lens=feature_lens,
|
||||
states=states,
|
||||
)
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
greedy_search(
|
||||
model=model, streams=streams, encoder_out=encoder_out,
|
||||
model=model,
|
||||
streams=streams,
|
||||
encoder_out=encoder_out,
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
modified_beam_search(
|
||||
@ -705,7 +725,10 @@ def decode_dataset(
|
||||
|
||||
while len(streams) > 0:
|
||||
finished_streams = decode_one_chunk(
|
||||
model=model, streams=streams, params=params, decoding_graph=decoding_graph,
|
||||
model=model,
|
||||
streams=streams,
|
||||
params=params,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
@ -825,7 +848,10 @@ def main():
|
||||
sp.load(bpe_model)
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(lexicon=lexicon, device=device,)
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
)
|
||||
params.blank_id = lexicon.token_table["<blk>"]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
@ -953,7 +979,9 @@ def main():
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params, test_set_name=test_set, results_dict=results_dict,
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
@ -103,23 +103,38 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-dim", type=int, default=512, help="Encoder output dimesion.",
|
||||
"--encoder-dim",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Encoder output dimesion.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-dim", type=int, default=512, help="Decoder output dimension.",
|
||||
"--decoder-dim",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Decoder output dimension.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-dim", type=int, default=512, help="Joiner output dimension.",
|
||||
"--joiner-dim",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Joiner output dimension.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dim-feedforward", type=int, default=2048, help="Dimension of feed forward.",
|
||||
"--dim-feedforward",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Dimension of feed forward.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-hidden-size", type=int, default=1024, help="Hidden dim for LSTM layers.",
|
||||
"--rnn-hidden-size",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Hidden dim for LSTM layers.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -156,7 +171,10 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
|
||||
"--world-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of GPUs for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -174,7 +192,10 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-epochs", type=int, default=40, help="Number of epochs to train.",
|
||||
"--num-epochs",
|
||||
type=int,
|
||||
default=40,
|
||||
help="Number of epochs to train.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -825,7 +846,9 @@ def train_one_epoch(
|
||||
and params.batch_idx_train % params.average_period == 0
|
||||
):
|
||||
update_averaged_model(
|
||||
params=params, model_cur=model, model_avg=model_avg,
|
||||
params=params,
|
||||
model_cur=model,
|
||||
model_avg=model_avg,
|
||||
)
|
||||
|
||||
if (
|
||||
@ -847,7 +870,9 @@ def train_one_epoch(
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
remove_checkpoints(
|
||||
out_dir=params.exp_dir, topk=params.keep_last_k, rank=rank,
|
||||
out_dir=params.exp_dir,
|
||||
topk=params.keep_last_k,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if batch_idx % params.log_interval == 0 and not params.print_diagnostics:
|
||||
@ -935,7 +960,10 @@ def run(rank, world_size, args):
|
||||
sp.load(bpe_model)
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(lexicon=lexicon, device=device,)
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
)
|
||||
|
||||
params.blank_id = lexicon.token_table["<blk>"]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
@ -986,7 +1014,7 @@ def run(rank, world_size, args):
|
||||
|
||||
if params.print_diagnostics:
|
||||
opts = diagnostics.TensorDiagnosticOptions(
|
||||
2 ** 22
|
||||
2**22
|
||||
) # allow 4 megabytes per sub-module
|
||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||
|
||||
|
@ -210,7 +210,9 @@ class TAL_CSASRAsrDataModule:
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
@ -355,7 +357,8 @@ class TAL_CSASRAsrDataModule:
|
||||
)
|
||||
else:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms, return_cuts=self.args.return_cuts,
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
@ -392,7 +395,10 @@ class TAL_CSASRAsrDataModule:
|
||||
)
|
||||
logging.info("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers,
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user