mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Bug fixes, config changes
This commit is contained in:
parent
d3b3592986
commit
bac72718f0
@ -123,7 +123,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-encoder-layers",
|
"--num-encoder-layers",
|
||||||
type=str,
|
type=str,
|
||||||
default="2,2,4,6,4,2",
|
default="2,4,4,4,4,4",
|
||||||
help="Number of zipformer encoder layers per stack, comma separated.",
|
help="Number of zipformer encoder layers per stack, comma separated.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -139,7 +139,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--feedforward-dim",
|
"--feedforward-dim",
|
||||||
type=str,
|
type=str,
|
||||||
default="384,768,1024,1536,1024,768",
|
default="384,512,1024,1536,1024,512",
|
||||||
help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
|
help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -160,7 +160,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--encoder-dim",
|
"--encoder-dim",
|
||||||
type=str,
|
type=str,
|
||||||
default="192,256,320,384,320,256",
|
default="192,192,256,320,256,192",
|
||||||
help="Embedding dimension in encoder stacks: a single int or comma-separated list."
|
help="Embedding dimension in encoder stacks: a single int or comma-separated list."
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -195,7 +195,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--encoder-unmasked-dim",
|
"--encoder-unmasked-dim",
|
||||||
type=str,
|
type=str,
|
||||||
default="164,192,256,256,256,192",
|
default="192,192,224,224,224,192",
|
||||||
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
|
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
|
||||||
"A single int or comma-separated list. Must be <= each corresponding encoder_dim."
|
"A single int or comma-separated list. Must be <= each corresponding encoder_dim."
|
||||||
)
|
)
|
||||||
|
|||||||
@ -215,6 +215,7 @@ class Zipformer(EncoderInterface):
|
|||||||
encoder.lr_scale = downsampling_factor[i] ** -0.33
|
encoder.lr_scale = downsampling_factor[i] ** -0.33
|
||||||
|
|
||||||
encoders.append(encoder)
|
encoders.append(encoder)
|
||||||
|
|
||||||
self.encoders = nn.ModuleList(encoders)
|
self.encoders = nn.ModuleList(encoders)
|
||||||
|
|
||||||
# initializes self.skip_layers and self.skip_modules
|
# initializes self.skip_layers and self.skip_modules
|
||||||
@ -327,8 +328,12 @@ class Zipformer(EncoderInterface):
|
|||||||
- lengths, a tensor of shape (batch_size,) containing the number
|
- lengths, a tensor of shape (batch_size,) containing the number
|
||||||
of frames in `embeddings` before padding.
|
of frames in `embeddings` before padding.
|
||||||
"""
|
"""
|
||||||
|
logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
|
||||||
|
|
||||||
x = self.encoder_embed(x)
|
x = self.encoder_embed(x)
|
||||||
|
|
||||||
|
logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
|
||||||
|
|
||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
@ -358,6 +363,7 @@ class Zipformer(EncoderInterface):
|
|||||||
feature_mask=feature_masks[i],
|
feature_mask=feature_masks[i],
|
||||||
src_key_padding_mask=None if mask is None else mask[...,::ds])
|
src_key_padding_mask=None if mask is None else mask[...,::ds])
|
||||||
outputs.append(x)
|
outputs.append(x)
|
||||||
|
logging.info(f"Memory allocated after stack {i}: {torch.cuda.memory_allocated() // 1000000}M")
|
||||||
|
|
||||||
x = self.downsample_output(x)
|
x = self.downsample_output(x)
|
||||||
# class Downsample has this rounding behavior..
|
# class Downsample has this rounding behavior..
|
||||||
@ -834,6 +840,7 @@ class SimpleDownsample(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.extra_proj = None
|
self.extra_proj = None
|
||||||
self.downsample = downsample
|
self.downsample = downsample
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
src: Tensor) -> Tensor:
|
src: Tensor) -> Tensor:
|
||||||
@ -867,6 +874,8 @@ class SimpleDownsample(torch.nn.Module):
|
|||||||
if self.extra_proj is not None:
|
if self.extra_proj is not None:
|
||||||
ans2 = self.extra_proj(src)
|
ans2 = self.extra_proj(src)
|
||||||
ans = torch.cat((ans, ans2), dim=2)
|
ans = torch.cat((ans, ans2), dim=2)
|
||||||
|
|
||||||
|
ans = ans[..., :self.out_channels]
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
|
||||||
@ -941,7 +950,7 @@ class SimpleCombiner(torch.nn.Module):
|
|||||||
dtype=src1.dtype)),
|
dtype=src1.dtype)),
|
||||||
dim=-1)
|
dim=-1)
|
||||||
else:
|
else:
|
||||||
src1 = src1[:src2_dim]
|
src1 = src1[...,:src2_dim]
|
||||||
|
|
||||||
src1 = src1 * weight1
|
src1 = src1 * weight1
|
||||||
src2 = src2 * (1.0 - weight1)
|
src2 = src2 * (1.0 - weight1)
|
||||||
@ -1917,7 +1926,7 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
out_channels: int,
|
out_channels: int,
|
||||||
layer1_channels: int = 8,
|
layer1_channels: int = 8,
|
||||||
layer2_channels: int = 32,
|
layer2_channels: int = 32,
|
||||||
layer3_channels: int = 96,
|
layer3_channels: int = 64,
|
||||||
dropout: FloatLike = 0.1,
|
dropout: FloatLike = 0.1,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user