mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fix typo and reformatted zipformer
This commit is contained in:
parent
82f34a2388
commit
208839fb9b
@ -273,8 +273,7 @@ def get_parser():
|
|||||||
"--context-size",
|
"--context-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=2,
|
||||||
help="The context size in the decoder. 1 means bigram; "
|
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||||
"2 means tri-gram",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-sym-per-frame",
|
"--max-sym-per-frame",
|
||||||
@ -371,9 +370,7 @@ def decode_one_batch(
|
|||||||
src_key_padding_mask = make_pad_mask(x_lens)
|
src_key_padding_mask = make_pad_mask(x_lens)
|
||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
encoder_out, encoder_out_lens = model.encoder(
|
encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
|
||||||
x, x_lens, src_key_padding_mask
|
|
||||||
)
|
|
||||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
|
|
||||||
hyps = []
|
hyps = []
|
||||||
@ -433,10 +430,7 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
elif (
|
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
params.decoding_method == "greedy_search"
|
|
||||||
and params.max_sym_per_frame == 1
|
|
||||||
):
|
|
||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
@ -567,9 +561,7 @@ def decode_dataset(
|
|||||||
if batch_idx % log_interval == 0:
|
if batch_idx % log_interval == 0:
|
||||||
batch_str = f"{batch_idx}/{num_batches}"
|
batch_str = f"{batch_idx}/{num_batches}"
|
||||||
|
|
||||||
logging.info(
|
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
|
||||||
)
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@ -602,8 +594,7 @@ def save_results(
|
|||||||
|
|
||||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
errs_info = (
|
errs_info = (
|
||||||
params.res_dir
|
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
|
||||||
)
|
)
|
||||||
with open(errs_info, "w") as f:
|
with open(errs_info, "w") as f:
|
||||||
print("settings\tWER", file=f)
|
print("settings\tWER", file=f)
|
||||||
@ -664,9 +655,7 @@ def main():
|
|||||||
if "LG" in params.decoding_method:
|
if "LG" in params.decoding_method:
|
||||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||||
elif "beam_search" in params.decoding_method:
|
elif "beam_search" in params.decoding_method:
|
||||||
params.suffix += (
|
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
@ -698,9 +687,9 @@ def main():
|
|||||||
|
|
||||||
if not params.use_averaged_model:
|
if not params.use_averaged_model:
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
filenames = find_checkpoints(
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
params.exp_dir, iteration=-params.iter
|
: params.avg
|
||||||
)[: params.avg]
|
]
|
||||||
if len(filenames) == 0:
|
if len(filenames) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No checkpoints found for"
|
f"No checkpoints found for"
|
||||||
@ -727,9 +716,9 @@ def main():
|
|||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
else:
|
else:
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
filenames = find_checkpoints(
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
params.exp_dir, iteration=-params.iter
|
: params.avg + 1
|
||||||
)[: params.avg + 1]
|
]
|
||||||
if len(filenames) == 0:
|
if len(filenames) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No checkpoints found for"
|
f"No checkpoints found for"
|
||||||
@ -788,9 +777,7 @@ def main():
|
|||||||
decoding_graph.scores *= params.ngram_lm_scale
|
decoding_graph.scores *= params.ngram_lm_scale
|
||||||
else:
|
else:
|
||||||
word_table = None
|
word_table = None
|
||||||
decoding_graph = k2.trivial_graph(
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
params.vocab_size - 1, device=device
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
decoding_graph = None
|
decoding_graph = None
|
||||||
word_table = None
|
word_table = None
|
||||||
|
|||||||
@ -62,10 +62,15 @@ class Decoder(nn.Module):
|
|||||||
)
|
)
|
||||||
# the balancers are to avoid any drift in the magnitude of the
|
# the balancers are to avoid any drift in the magnitude of the
|
||||||
# embeddings, which would interact badly with parameter averaging.
|
# embeddings, which would interact badly with parameter averaging.
|
||||||
self.balancer = Balancer(decoder_dim, channel_dim=-1,
|
self.balancer = Balancer(
|
||||||
min_positive=0.0, max_positive=1.0,
|
decoder_dim,
|
||||||
min_abs=0.5, max_abs=1.0,
|
channel_dim=-1,
|
||||||
prob=0.05)
|
min_positive=0.0,
|
||||||
|
max_positive=1.0,
|
||||||
|
min_abs=0.5,
|
||||||
|
max_abs=1.0,
|
||||||
|
prob=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
self.blank_id = blank_id
|
self.blank_id = blank_id
|
||||||
|
|
||||||
@ -82,10 +87,15 @@ class Decoder(nn.Module):
|
|||||||
groups=decoder_dim // 4, # group size == 4
|
groups=decoder_dim // 4, # group size == 4
|
||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
self.balancer2 = Balancer(decoder_dim, channel_dim=-1,
|
self.balancer2 = Balancer(
|
||||||
min_positive=0.0, max_positive=1.0,
|
decoder_dim,
|
||||||
min_abs=0.5, max_abs=1.0,
|
channel_dim=-1,
|
||||||
prob=0.05)
|
min_positive=0.0,
|
||||||
|
max_positive=1.0,
|
||||||
|
min_abs=0.5,
|
||||||
|
max_abs=1.0,
|
||||||
|
prob=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@ -108,9 +118,7 @@ class Decoder(nn.Module):
|
|||||||
if self.context_size > 1:
|
if self.context_size > 1:
|
||||||
embedding_out = embedding_out.permute(0, 2, 1)
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
if need_pad is True:
|
if need_pad is True:
|
||||||
embedding_out = F.pad(
|
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
|
||||||
embedding_out, pad=(self.context_size - 1, 0)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# During inference time, there is no need to do extra padding
|
# During inference time, there is no need to do extra padding
|
||||||
# as we only need one output
|
# as we only need one output
|
||||||
|
|||||||
@ -257,6 +257,7 @@ def get_parser():
|
|||||||
|
|
||||||
class EncoderModel(nn.Module):
|
class EncoderModel(nn.Module):
|
||||||
"""A wrapper for encoder and encoder_embed"""
|
"""A wrapper for encoder and encoder_embed"""
|
||||||
|
|
||||||
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
|
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
@ -275,9 +276,7 @@ class EncoderModel(nn.Module):
|
|||||||
src_key_padding_mask = make_pad_mask(x_lens)
|
src_key_padding_mask = make_pad_mask(x_lens)
|
||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
encoder_out, encoder_out_lens = self.encoder(
|
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
|
||||||
x, x_lens, src_key_padding_mask
|
|
||||||
)
|
|
||||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
|
|
||||||
return encoder_out, encoder_out_lens
|
return encoder_out, encoder_out_lens
|
||||||
|
|||||||
@ -52,12 +52,13 @@ class Joiner(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape (N, T, s_range, C).
|
Return a tensor of shape (N, T, s_range, C).
|
||||||
"""
|
"""
|
||||||
assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape)
|
assert encoder_out.ndim == decoder_out.ndim, (
|
||||||
|
encoder_out.shape,
|
||||||
|
decoder_out.shape,
|
||||||
|
)
|
||||||
|
|
||||||
if project_input:
|
if project_input:
|
||||||
logit = self.encoder_proj(encoder_out) + self.decoder_proj(
|
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
|
||||||
decoder_out
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logit = encoder_out + decoder_out
|
logit = encoder_out + decoder_out
|
||||||
|
|
||||||
|
|||||||
@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer):
|
|||||||
|
|
||||||
yield tuples # <-- calling code will do the actual optimization here!
|
yield tuples # <-- calling code will do the actual optimization here!
|
||||||
|
|
||||||
for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
|
for (stacked_params, _state, _names), batch in zip(tuples, batches):
|
||||||
for i, p in enumerate(batch): # batch is list of Parameter
|
for i, p in enumerate(batch): # batch is list of Parameter
|
||||||
p.copy_(stacked_params[i])
|
p.copy_(stacked_params[i])
|
||||||
|
|
||||||
@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
size_update_period=4,
|
size_update_period=4,
|
||||||
clipping_update_period=100,
|
clipping_update_period=100,
|
||||||
):
|
):
|
||||||
|
|
||||||
defaults = dict(
|
defaults = dict(
|
||||||
lr=lr,
|
lr=lr,
|
||||||
clipping_scale=clipping_scale,
|
clipping_scale=clipping_scale,
|
||||||
@ -327,9 +326,7 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
batch = True
|
batch = True
|
||||||
|
|
||||||
for group, group_params_names in zip(self.param_groups, self.parameters_names):
|
for group, group_params_names in zip(self.param_groups, self.parameters_names):
|
||||||
|
|
||||||
with self.batched_params(group["params"], group_params_names) as batches:
|
with self.batched_params(group["params"], group_params_names) as batches:
|
||||||
|
|
||||||
# batches is list of pairs (stacked_param, state). stacked_param is like
|
# batches is list of pairs (stacked_param, state). stacked_param is like
|
||||||
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
|
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
|
||||||
# a stacking dim, it is not a real dim.
|
# a stacking dim, it is not a real dim.
|
||||||
@ -428,7 +425,7 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
clipping_update_period = group["clipping_update_period"]
|
clipping_update_period = group["clipping_update_period"]
|
||||||
|
|
||||||
tot_sumsq = torch.tensor(0.0, device=first_p.device)
|
tot_sumsq = torch.tensor(0.0, device=first_p.device)
|
||||||
for (p, state, param_names) in tuples:
|
for p, state, param_names in tuples:
|
||||||
grad = p.grad
|
grad = p.grad
|
||||||
if grad.is_sparse:
|
if grad.is_sparse:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@ -513,7 +510,7 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
from tuples, we still pass it to save some time.
|
from tuples, we still pass it to save some time.
|
||||||
"""
|
"""
|
||||||
all_sumsq_orig = {}
|
all_sumsq_orig = {}
|
||||||
for (p, state, batch_param_names) in tuples:
|
for p, state, batch_param_names in tuples:
|
||||||
# p is a stacked batch parameters.
|
# p is a stacked batch parameters.
|
||||||
batch_grad = p.grad
|
batch_grad = p.grad
|
||||||
if p.numel() == p.shape[0]: # a batch of scalars
|
if p.numel() == p.shape[0]: # a batch of scalars
|
||||||
@ -529,7 +526,6 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
for name, sumsq_orig, rms, grad in zip(
|
for name, sumsq_orig, rms, grad in zip(
|
||||||
batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
|
batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
|
||||||
):
|
):
|
||||||
|
|
||||||
proportion_orig = sumsq_orig / tot_sumsq
|
proportion_orig = sumsq_orig / tot_sumsq
|
||||||
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
|
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
|
||||||
|
|
||||||
@ -667,8 +663,7 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
# We have to look at the trained model for parameters at or around the
|
# We have to look at the trained model for parameters at or around the
|
||||||
# param_max_rms, because sometimes they can indicate a problem with the
|
# param_max_rms, because sometimes they can indicate a problem with the
|
||||||
# topology or settings.
|
# topology or settings.
|
||||||
scale_step = torch.minimum(scale_step,
|
scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
|
||||||
(param_max_rms - param_rms) / param_rms)
|
|
||||||
|
|
||||||
delta = state["delta"]
|
delta = state["delta"]
|
||||||
# the factor of (1-beta1) relates to momentum.
|
# the factor of (1-beta1) relates to momentum.
|
||||||
@ -879,7 +874,8 @@ class Eden(LRScheduler):
|
|||||||
warmup_factor = (
|
warmup_factor = (
|
||||||
1.0
|
1.0
|
||||||
if self.batch >= self.warmup_batches
|
if self.batch >= self.warmup_batches
|
||||||
else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
|
else self.warmup_start
|
||||||
|
+ (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
|
||||||
# else 0.5 + 0.5 * (self.batch / self.warmup_batches)
|
# else 0.5 + 0.5 * (self.batch / self.warmup_batches)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -323,9 +323,7 @@ def main():
|
|||||||
src_key_padding_mask = make_pad_mask(x_lens)
|
src_key_padding_mask = make_pad_mask(x_lens)
|
||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
encoder_out, encoder_out_lens = model.encoder(
|
encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
|
||||||
x, x_lens, src_key_padding_mask
|
|
||||||
)
|
|
||||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
|
|
||||||
hyps = []
|
hyps = []
|
||||||
|
|||||||
@ -100,17 +100,13 @@ class Model(nn.Module):
|
|||||||
self.encoder_embed = encoder_embed
|
self.encoder_embed = encoder_embed
|
||||||
self.encoder_proj = encoder_proj
|
self.encoder_proj = encoder_proj
|
||||||
|
|
||||||
def forward(
|
def forward(self, feature: Tensor, feature_lens: Tensor) -> Tuple[Tensor, Tensor]:
|
||||||
self, feature: Tensor, feature_lens: Tensor
|
|
||||||
) -> Tuple[Tensor, Tensor]:
|
|
||||||
x, x_lens = self.encoder_embed(feature, feature_lens)
|
x, x_lens = self.encoder_embed(feature, feature_lens)
|
||||||
|
|
||||||
src_key_padding_mask = make_pad_mask(x_lens)
|
src_key_padding_mask = make_pad_mask(x_lens)
|
||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
encoder_out, encoder_out_lens = self.encoder(
|
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
|
||||||
x, x_lens, src_key_padding_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
logits = self.encoder_proj(encoder_out)
|
logits = self.encoder_proj(encoder_out)
|
||||||
@ -168,9 +164,7 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
formatter = (
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
|
||||||
)
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
main()
|
main()
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -81,7 +81,7 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=28,
|
default=28,
|
||||||
help="""It specifies the checkpoint to use for decoding.
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
Note: Epoch counts from 0.
|
Note: Epoch counts from 1.
|
||||||
You can specify --avg to use more checkpoints for model averaging.""",
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -282,9 +282,7 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
|
|||||||
)
|
)
|
||||||
batch_states.append(cached_embed_left_pad)
|
batch_states.append(cached_embed_left_pad)
|
||||||
|
|
||||||
processed_lens = torch.cat(
|
processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
|
||||||
[state_list[i][-1] for i in range(batch_size)], dim=0
|
|
||||||
)
|
|
||||||
batch_states.append(processed_lens)
|
batch_states.append(processed_lens)
|
||||||
|
|
||||||
return batch_states
|
return batch_states
|
||||||
@ -322,9 +320,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
|
|||||||
for layer in range(tot_num_layers):
|
for layer in range(tot_num_layers):
|
||||||
layer_offset = layer * 6
|
layer_offset = layer * 6
|
||||||
# cached_key: (left_context_len, batch_size, key_dim)
|
# cached_key: (left_context_len, batch_size, key_dim)
|
||||||
cached_key_list = batch_states[layer_offset].chunk(
|
cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
|
||||||
chunks=batch_size, dim=1
|
|
||||||
)
|
|
||||||
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
|
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
|
||||||
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
|
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
|
||||||
chunks=batch_size, dim=1
|
chunks=batch_size, dim=1
|
||||||
@ -355,9 +351,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
|
|||||||
cached_conv2_list[i],
|
cached_conv2_list[i],
|
||||||
]
|
]
|
||||||
|
|
||||||
cached_embed_left_pad_list = batch_states[-2].chunk(
|
cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
|
||||||
chunks=batch_size, dim=0
|
|
||||||
)
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
state_list[i].append(cached_embed_left_pad_list[i])
|
state_list[i].append(cached_embed_left_pad_list[i])
|
||||||
|
|
||||||
@ -404,9 +398,7 @@ def streaming_forward(
|
|||||||
new_processed_lens = processed_lens + x_lens
|
new_processed_lens = processed_lens + x_lens
|
||||||
|
|
||||||
# (batch, left_context_size + chunk_size)
|
# (batch, left_context_size + chunk_size)
|
||||||
src_key_padding_mask = torch.cat(
|
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
|
||||||
[processed_mask, src_key_padding_mask], dim=1
|
|
||||||
)
|
|
||||||
|
|
||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
encoder_states = states[:-2]
|
encoder_states = states[:-2]
|
||||||
@ -494,9 +486,7 @@ def decode_one_chunk(
|
|||||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
greedy_search(
|
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
|
||||||
model=model, encoder_out=encoder_out, streams=decode_streams
|
|
||||||
)
|
|
||||||
elif params.decoding_method == "fast_beam_search":
|
elif params.decoding_method == "fast_beam_search":
|
||||||
processed_lens = torch.tensor(processed_lens, device=device)
|
processed_lens = torch.tensor(processed_lens, device=device)
|
||||||
processed_lens = processed_lens + encoder_out_lens
|
processed_lens = processed_lens + encoder_out_lens
|
||||||
@ -517,9 +507,7 @@ def decode_one_chunk(
|
|||||||
num_active_paths=params.num_active_paths,
|
num_active_paths=params.num_active_paths,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
|
||||||
)
|
|
||||||
|
|
||||||
states = unstack_states(new_states)
|
states = unstack_states(new_states)
|
||||||
|
|
||||||
@ -577,9 +565,7 @@ def decode_dataset(
|
|||||||
decode_streams = []
|
decode_streams = []
|
||||||
for num, cut in enumerate(cuts):
|
for num, cut in enumerate(cuts):
|
||||||
# each utterance has a DecodeStream.
|
# each utterance has a DecodeStream.
|
||||||
initial_states = get_init_states(
|
initial_states = get_init_states(model=model, batch_size=1, device=device)
|
||||||
model=model, batch_size=1, device=device
|
|
||||||
)
|
|
||||||
decode_stream = DecodeStream(
|
decode_stream = DecodeStream(
|
||||||
params=params,
|
params=params,
|
||||||
cut_id=cut.id,
|
cut_id=cut.id,
|
||||||
@ -649,9 +635,7 @@ def decode_dataset(
|
|||||||
elif params.decoding_method == "modified_beam_search":
|
elif params.decoding_method == "modified_beam_search":
|
||||||
key = f"num_active_paths_{params.num_active_paths}"
|
key = f"num_active_paths_{params.num_active_paths}"
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
|
||||||
)
|
|
||||||
return {key: decode_results}
|
return {key: decode_results}
|
||||||
|
|
||||||
|
|
||||||
@ -684,8 +668,7 @@ def save_results(
|
|||||||
|
|
||||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
errs_info = (
|
errs_info = (
|
||||||
params.res_dir
|
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
|
||||||
)
|
)
|
||||||
with open(errs_info, "w") as f:
|
with open(errs_info, "w") as f:
|
||||||
print("settings\tWER", file=f)
|
print("settings\tWER", file=f)
|
||||||
@ -718,9 +701,7 @@ def main():
|
|||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
assert params.causal, params.causal
|
assert params.causal, params.causal
|
||||||
assert (
|
assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
|
||||||
"," not in params.chunk_size
|
|
||||||
), "chunk_size should be one value in decoding."
|
|
||||||
assert (
|
assert (
|
||||||
"," not in params.left_context_frames
|
"," not in params.left_context_frames
|
||||||
), "left_context_frames should be one value in decoding."
|
), "left_context_frames should be one value in decoding."
|
||||||
@ -760,9 +741,9 @@ def main():
|
|||||||
|
|
||||||
if not params.use_averaged_model:
|
if not params.use_averaged_model:
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
filenames = find_checkpoints(
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
params.exp_dir, iteration=-params.iter
|
: params.avg
|
||||||
)[: params.avg]
|
]
|
||||||
if len(filenames) == 0:
|
if len(filenames) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No checkpoints found for"
|
f"No checkpoints found for"
|
||||||
@ -789,9 +770,9 @@ def main():
|
|||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
else:
|
else:
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
filenames = find_checkpoints(
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
params.exp_dir, iteration=-params.iter
|
: params.avg + 1
|
||||||
)[: params.avg + 1]
|
]
|
||||||
if len(filenames) == 0:
|
if len(filenames) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No checkpoints found for"
|
f"No checkpoints found for"
|
||||||
|
|||||||
@ -107,9 +107,7 @@ class ConvNeXt(nn.Module):
|
|||||||
if layerdrop_rate != 0.0:
|
if layerdrop_rate != 0.0:
|
||||||
batch_size = x.shape[0]
|
batch_size = x.shape[0]
|
||||||
mask = (
|
mask = (
|
||||||
torch.rand(
|
torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
|
||||||
(batch_size, 1, 1, 1), dtype=x.dtype, device=x.device
|
|
||||||
)
|
|
||||||
> layerdrop_rate
|
> layerdrop_rate
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -275,9 +273,7 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
# many copies of this extra gradient term.
|
# many copies of this extra gradient term.
|
||||||
self.out_whiten = Whiten(
|
self.out_whiten = Whiten(
|
||||||
num_groups=1,
|
num_groups=1,
|
||||||
whitening_limit=ScheduledFloat(
|
whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
|
||||||
(0.0, 4.0), (20000.0, 8.0), default=4.0
|
|
||||||
),
|
|
||||||
prob=(0.025, 0.25),
|
prob=(0.025, 0.25),
|
||||||
grad_scale=0.02,
|
grad_scale=0.02,
|
||||||
)
|
)
|
||||||
@ -400,8 +396,8 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
left_pad = self.convnext.padding[0]
|
left_pad = self.convnext.padding[0]
|
||||||
freq = self.out_width
|
freq = self.out_width
|
||||||
channels = self.layer3_channels
|
channels = self.layer3_channels
|
||||||
cached_embed_left_pad = torch.zeros(
|
cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
|
||||||
batch_size, channels, left_pad, freq
|
device
|
||||||
).to(device)
|
)
|
||||||
|
|
||||||
return cached_embed_left_pad
|
return cached_embed_left_pad
|
||||||
|
|||||||
@ -408,7 +408,7 @@ def get_parser():
|
|||||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||||
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
|
||||||
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
|
||||||
end of each epoch where `xxx` is the epoch number counting from 0.
|
end of each epoch where `xxx` is the epoch number counting from 1.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user