fixed formatting issues

This commit is contained in:
jinzr 2023-09-24 17:02:01 +08:00
parent 78b2279969
commit 39cf318ba8
8 changed files with 447 additions and 400 deletions

View File

@ -230,7 +230,9 @@ class Conformer(Transformer):
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
) # (T, B, F) ) # (T, B, F)
else: else:
x = self.encoder(x, pos_emb, src_key_padding_mask=src_key_padding_mask) # (T, B, F) x = self.encoder(
x, pos_emb, src_key_padding_mask=src_key_padding_mask
) # (T, B, F)
if self.normalize_before: if self.normalize_before:
x = self.after_norm(x) x = self.after_norm(x)

View File

@ -61,10 +61,15 @@ class Decoder(nn.Module):
) )
# the balancers are to avoid any drift in the magnitude of the # the balancers are to avoid any drift in the magnitude of the
# embeddings, which would interact badly with parameter averaging. # embeddings, which would interact badly with parameter averaging.
self.balancer = Balancer(decoder_dim, channel_dim=-1, self.balancer = Balancer(
min_positive=0.0, max_positive=1.0, decoder_dim,
min_abs=0.5, max_abs=1.0, channel_dim=-1,
prob=0.05) min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
self.blank_id = blank_id self.blank_id = blank_id
@ -81,10 +86,15 @@ class Decoder(nn.Module):
groups=decoder_dim // 4, # group size == 4 groups=decoder_dim // 4, # group size == 4
bias=False, bias=False,
) )
self.balancer2 = Balancer(decoder_dim, channel_dim=-1, self.balancer2 = Balancer(
min_positive=0.0, max_positive=1.0, decoder_dim,
min_abs=0.5, max_abs=1.0, channel_dim=-1,
prob=0.05) min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
""" """
@ -107,9 +117,7 @@ class Decoder(nn.Module):
if self.context_size > 1: if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1) embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True: if need_pad is True:
embedding_out = F.pad( embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
embedding_out, pad=(self.context_size - 1, 0)
)
else: else:
# During inference time, there is no need to do extra padding # During inference time, there is no need to do extra padding
# as we only need one output # as we only need one output

View File

@ -52,12 +52,13 @@ class Joiner(nn.Module):
Returns: Returns:
Return a tensor of shape (N, T, s_range, C). Return a tensor of shape (N, T, s_range, C).
""" """
assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape) assert encoder_out.ndim == decoder_out.ndim, (
encoder_out.shape,
decoder_out.shape,
)
if project_input: if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj( logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
decoder_out
)
else: else:
logit = encoder_out + decoder_out logit = encoder_out + decoder_out

View File

@ -303,7 +303,9 @@ def main():
for test_set, test_dl in zip(test_sets, test_dl): for test_set, test_dl in zip(test_sets, test_dl):
start_time = time.time() start_time = time.time()
results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table) results, total_duration = decode_dataset(
dl=test_dl, model=model, token_table=token_table
)
end_time = time.time() end_time = time.time()
elapsed_seconds = end_time - start_time elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / total_duration rtf = elapsed_seconds / total_duration

View File

@ -100,17 +100,13 @@ class Model(nn.Module):
self.encoder_embed = encoder_embed self.encoder_embed = encoder_embed
self.encoder_proj = encoder_proj self.encoder_proj = encoder_proj
def forward( def forward(self, feature: Tensor, feature_lens: Tensor) -> Tuple[Tensor, Tensor]:
self, feature: Tensor, feature_lens: Tensor
) -> Tuple[Tensor, Tensor]:
x, x_lens = self.encoder_embed(feature, feature_lens) x, x_lens = self.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens) src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder( encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C) encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
logits = self.encoder_proj(encoder_out) logits = self.encoder_proj(encoder_out)
@ -168,9 +164,7 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
formatter = ( formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
main() main()

File diff suppressed because it is too large Load Diff

View File

@ -282,9 +282,7 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
) )
batch_states.append(cached_embed_left_pad) batch_states.append(cached_embed_left_pad)
processed_lens = torch.cat( processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
[state_list[i][-1] for i in range(batch_size)], dim=0
)
batch_states.append(processed_lens) batch_states.append(processed_lens)
return batch_states return batch_states
@ -322,9 +320,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
for layer in range(tot_num_layers): for layer in range(tot_num_layers):
layer_offset = layer * 6 layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim) # cached_key: (left_context_len, batch_size, key_dim)
cached_key_list = batch_states[layer_offset].chunk( cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
chunks=batch_size, dim=1
)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
chunks=batch_size, dim=1 chunks=batch_size, dim=1
@ -355,9 +351,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
cached_conv2_list[i], cached_conv2_list[i],
] ]
cached_embed_left_pad_list = batch_states[-2].chunk( cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
chunks=batch_size, dim=0
)
for i in range(batch_size): for i in range(batch_size):
state_list[i].append(cached_embed_left_pad_list[i]) state_list[i].append(cached_embed_left_pad_list[i])
@ -380,11 +374,7 @@ def streaming_forward(
Returns encoder outputs, output lengths, and updated states. Returns encoder outputs, output lengths, and updated states.
""" """
cached_embed_left_pad = states[-2] cached_embed_left_pad = states[-2]
( (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
x,
x_lens,
new_cached_embed_left_pad,
) = model.encoder_embed.streaming_forward(
x=features, x=features,
x_lens=feature_lens, x_lens=feature_lens,
cached_left_pad=cached_embed_left_pad, cached_left_pad=cached_embed_left_pad,
@ -404,9 +394,7 @@ def streaming_forward(
new_processed_lens = processed_lens + x_lens new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size) # (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat( src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
[processed_mask, src_key_padding_mask], dim=1
)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2] encoder_states = states[:-2]
@ -494,9 +482,7 @@ def decode_one_chunk(
encoder_out = model.joiner.encoder_proj(encoder_out) encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
greedy_search( greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
model=model, encoder_out=encoder_out, streams=decode_streams
)
elif params.decoding_method == "fast_beam_search": elif params.decoding_method == "fast_beam_search":
processed_lens = torch.tensor(processed_lens, device=device) processed_lens = torch.tensor(processed_lens, device=device)
processed_lens = processed_lens + encoder_out_lens processed_lens = processed_lens + encoder_out_lens
@ -517,9 +503,7 @@ def decode_one_chunk(
num_active_paths=params.num_active_paths, num_active_paths=params.num_active_paths,
) )
else: else:
raise ValueError( raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
f"Unsupported decoding method: {params.decoding_method}"
)
states = unstack_states(new_states) states = unstack_states(new_states)
@ -577,9 +561,7 @@ def decode_dataset(
decode_streams = [] decode_streams = []
for num, cut in enumerate(cuts): for num, cut in enumerate(cuts):
# each utterance has a DecodeStream. # each utterance has a DecodeStream.
initial_states = get_init_states( initial_states = get_init_states(model=model, batch_size=1, device=device)
model=model, batch_size=1, device=device
)
decode_stream = DecodeStream( decode_stream = DecodeStream(
params=params, params=params,
cut_id=cut.id, cut_id=cut.id,
@ -649,9 +631,7 @@ def decode_dataset(
elif params.decoding_method == "modified_beam_search": elif params.decoding_method == "modified_beam_search":
key = f"num_active_paths_{params.num_active_paths}" key = f"num_active_paths_{params.num_active_paths}"
else: else:
raise ValueError( raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
f"Unsupported decoding method: {params.decoding_method}"
)
return {key: decode_results} return {key: decode_results}
@ -684,8 +664,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tWER", file=f) print("settings\tWER", file=f)
@ -718,9 +697,7 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
assert params.causal, params.causal assert params.causal, params.causal
assert ( assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert ( assert (
"," not in params.left_context_frames "," not in params.left_context_frames
), "left_context_frames should be one value in decoding." ), "left_context_frames should be one value in decoding."
@ -760,9 +737,9 @@ def main():
if not params.use_averaged_model: if not params.use_averaged_model:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg
)[: params.avg] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for"
@ -789,9 +766,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))
else: else:
if params.iter > 0: if params.iter > 0:
filenames = find_checkpoints( filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
params.exp_dir, iteration=-params.iter : params.avg + 1
)[: params.avg + 1] ]
if len(filenames) == 0: if len(filenames) == 0:
raise ValueError( raise ValueError(
f"No checkpoints found for" f"No checkpoints found for"

View File

@ -107,9 +107,7 @@ class ConvNeXt(nn.Module):
if layerdrop_rate != 0.0: if layerdrop_rate != 0.0:
batch_size = x.shape[0] batch_size = x.shape[0]
mask = ( mask = (
torch.rand( torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
(batch_size, 1, 1, 1), dtype=x.dtype, device=x.device
)
> layerdrop_rate > layerdrop_rate
) )
else: else:
@ -278,9 +276,7 @@ class Conv2dSubsampling(nn.Module):
# many copies of this extra gradient term. # many copies of this extra gradient term.
self.out_whiten = Whiten( self.out_whiten = Whiten(
num_groups=1, num_groups=1,
whitening_limit=ScheduledFloat( whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
(0.0, 4.0), (20000.0, 8.0), default=4.0
),
prob=(0.025, 0.25), prob=(0.025, 0.25),
grad_scale=0.02, grad_scale=0.02,
) )
@ -331,7 +327,7 @@ class Conv2dSubsampling(nn.Module):
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
x_lens = (x_lens - 7) // 2 x_lens = (x_lens - 7) // 2
assert x.size(1) == x_lens.max().item() , (x.size(1), x_lens.max()) assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max())
return x, x_lens return x, x_lens
@ -403,8 +399,8 @@ class Conv2dSubsampling(nn.Module):
left_pad = self.convnext.padding[0] left_pad = self.convnext.padding[0]
freq = self.out_width freq = self.out_width
channels = self.layer3_channels channels = self.layer3_channels
cached_embed_left_pad = torch.zeros( cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
batch_size, channels, left_pad, freq device
).to(device) )
return cached_embed_left_pad return cached_embed_left_pad