Fix for black

This commit is contained in:
Yifan Yang 2023-06-14 18:28:17 +08:00
parent c43d4ced9e
commit 5b049a1a3a
4 changed files with 39 additions and 71 deletions

View File

@ -273,8 +273,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -425,10 +424,7 @@ def decode_one_batch(
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
elif ( elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -559,9 +555,7 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info( logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -594,8 +588,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tWER", file=f) print("settings\tWER", file=f)
@ -656,9 +649,7 @@ def main():
if "LG" in params.decoding_method: if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += ( params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -690,9 +681,9 @@ def main():
if not params.use_averaged_model: if not params.use_averaged_model:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg
)[: params.avg] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for"
@ -719,9 +710,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))
else: else:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg + 1
)[: params.avg + 1] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for"
@ -780,9 +771,7 @@ def main():
decoding_graph.scores *= params.ngram_lm_scale decoding_graph.scores *= params.ngram_lm_scale
else: else:
word_table = None word_table = None
decoding_graph = k2.trivial_graph( decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
params.vocab_size - 1, device=device
)
else: else:
decoding_graph = None decoding_graph = None
word_table = None word_table = None

View File

@ -257,6 +257,7 @@ def get_parser():
class EncoderModel(nn.Module): class EncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed""" """A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__() super().__init__()
self.encoder = encoder self.encoder = encoder
@ -275,9 +276,7 @@ class EncoderModel(nn.Module):
src_key_padding_mask = make_pad_mask(x_lens) src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder( encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return encoder_out, encoder_out_lens return encoder_out, encoder_out_lens

View File

@ -282,9 +282,7 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
) )
batch_states.append(cached_embed_left_pad) batch_states.append(cached_embed_left_pad)
processed_lens = torch.cat( processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
[state_list[i][-1] for i in range(batch_size)], dim=0
)
batch_states.append(processed_lens) batch_states.append(processed_lens)
return batch_states return batch_states
@ -322,9 +320,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
for layer in range(tot_num_layers): for layer in range(tot_num_layers):
layer_offset = layer * 6 layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim) # cached_key: (left_context_len, batch_size, key_dim)
cached_key_list = batch_states[layer_offset].chunk( cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
chunks=batch_size, dim=1
)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
chunks=batch_size, dim=1 chunks=batch_size, dim=1
@ -355,9 +351,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
cached_conv2_list[i], cached_conv2_list[i],
] ]
cached_embed_left_pad_list = batch_states[-2].chunk( cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
chunks=batch_size, dim=0
)
for i in range(batch_size): for i in range(batch_size):
state_list[i].append(cached_embed_left_pad_list[i]) state_list[i].append(cached_embed_left_pad_list[i])
@ -404,9 +398,7 @@ def streaming_forward(
new_processed_lens = processed_lens + x_lens new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size) # (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat( src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
[processed_mask, src_key_padding_mask], dim=1
)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2] encoder_states = states[:-2]
@ -494,9 +486,7 @@ def decode_one_chunk(
encoder_out = model.joiner.encoder_proj(encoder_out) encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
greedy_search( greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
model=model, encoder_out=encoder_out, streams=decode_streams
)
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
processed_lens = torch.tensor(processed_lens, device=device) processed_lens = torch.tensor(processed_lens, device=device)
processed_lens = processed_lens + encoder_out_lens processed_lens = processed_lens + encoder_out_lens
@ -517,9 +507,7 @@ def decode_one_chunk(
num_active_paths=params.num_active_paths, num_active_paths=params.num_active_paths,
) )
else: else:
raise ValueError( raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
f"Unsupported decoding method: {params.decoding_method}"
)
states = unstack_states(new_states) states = unstack_states(new_states)
@ -577,9 +565,7 @@ def decode_dataset(
decode_streams = [] decode_streams = []
for num, cut in enumerate(cuts): for num, cut in enumerate(cuts):
# each utterance has a DecodeStream. # each utterance has a DecodeStream.
initial_states = get_init_states( initial_states = get_init_states(model=model, batch_size=1, device=device)
model=model, batch_size=1, device=device
)
decode_stream = DecodeStream( decode_stream = DecodeStream(
params=params, params=params,
cut_id=cut.id, cut_id=cut.id,
@ -649,9 +635,7 @@ def decode_dataset(
elif params.decoding_method == "modified_beam_search": elif params.decoding_method == "modified_beam_search":
key = f"num_active_paths_{params.num_active_paths}" key = f"num_active_paths_{params.num_active_paths}"
else: else:
raise ValueError( raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
f"Unsupported decoding method: {params.decoding_method}"
)
return {key: decode_results} return {key: decode_results}
@ -684,8 +668,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tWER", file=f) print("settings\tWER", file=f)
@ -718,9 +701,7 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
assert params.causal, params.causal assert params.causal, params.causal
assert ( assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert ( assert (
"," not in params.left_context_frames "," not in params.left_context_frames
), "left_context_frames should be one value in decoding." ), "left_context_frames should be one value in decoding."
@ -760,9 +741,9 @@ def main():
if not params.use_averaged_model: if not params.use_averaged_model:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg
)[: params.avg] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for"
@ -789,9 +770,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))
else: else:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg + 1
)[: params.avg + 1] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for"

View File

@ -606,11 +606,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
def get_model(params: AttributeDict) -> nn.Module: def get_model(params: AttributeDict) -> nn.Module:
assert ( assert params.use_transducer or params.use_ctc, (
params.use_transducer or params.use_ctc f"At least one of them should be True, "
), (f"At least one of them should be True, "
f"but got params.use_transducer={params.use_transducer}, " f"but got params.use_transducer={params.use_transducer}, "
f"params.use_ctc={params.use_ctc}") f"params.use_ctc={params.use_ctc}"
)
encoder_embed = get_encoder_embed(params) encoder_embed = get_encoder_embed(params)
encoder = get_encoder_model(params) encoder = get_encoder_model(params)
@ -810,17 +810,16 @@ def compute_loss(
# take down the scale on the simple loss from 1.0 at the start # take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step. # to params.simple_loss scale by warm_step.
simple_loss_scale = ( simple_loss_scale = (
s if batch_idx_train >= warm_step s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
) )
pruned_loss_scale = ( pruned_loss_scale = (
1.0 if batch_idx_train >= warm_step 1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step) else 0.1 + 0.9 * (batch_idx_train / warm_step)
) )
loss += ( loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
if params.use_ctc: if params.use_ctc:
loss += params.ctc_loss_scale * ctc_loss loss += params.ctc_loss_scale * ctc_loss