Rename Conformer to Zipformer
This commit is contained in:
parent 3f05e47447
commit 5dfa141ca5
@@ -94,35 +94,35 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "--num-encoder-layers",
         type=str,
         default="7,7",
-        help="Number of conformer encoder layers, comma separated.",
+        help="Number of zipformer encoder layers, comma separated.",
     )
 
     parser.add_argument(
         "--feedforward-dims",
         type=str,
         default="1536,1536",
-        help="Feedforward dimension of the conformer encoder layers, comma separated.",
+        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
     )
 
     parser.add_argument(
         "--nhead",
         type=str,
         default="8,8",
-        help="Number of attention heads in the conformer encoder layers.",
+        help="Number of attention heads in the zipformer encoder layers.",
     )
 
     parser.add_argument(
         "--encoder-dims",
         type=str,
         default="384,384",
-        help="Embedding dimension in the 2 blocks of conformer encoder layers, comma separated"
+        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated"
     )
 
     parser.add_argument(
         "--attention-dims",
         type=str,
         default="192,192",
-        help="""Attention dimension in the 2 blocks of conformer encoder layers, comma separated;
+        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
             not the same as embedding dimension."""
     )
@@ -136,7 +136,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     )
 
     parser.add_argument(
-        "--conformer-subsampling-factor",
+        "--zipformer-subsampling-factor",
         type=int,
         default=2,
         help="Subsampling factor for 2nd stack of encoder layers.",
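
Note: because the option string itself changes here, commands that previously passed --conformer-subsampling-factor must switch to the new flag. A minimal sketch of how the renamed option parses; the flag name and default are from the hunk above, and the standalone parser is only for illustration:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--zipformer-subsampling-factor",
    type=int,
    default=2,
    help="Subsampling factor for 2nd stack of encoder layers.",
)

# argparse maps dashes in option strings to underscores in attribute names.
args = parser.parse_args(["--zipformer-subsampling-factor", "4"])
assert args.zipformer_subsampling_factor == 4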
@@ -419,7 +419,7 @@ def get_params() -> AttributeDict:
             "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
-            # parameters for conformer
+            # parameters for zipformer
             "feature_dim": 80,
             "subsampling_factor": 4,
             "warm_step": 2000,
@@ -431,13 +431,13 @@ def get_params() -> AttributeDict:
 
 
 def get_encoder_model(params: AttributeDict) -> nn.Module:
-    # TODO: We can add an option to switch between Conformer and Transformer
+    # TODO: We can add an option to switch between Zipformer and Transformer
     def to_int_tuple(s: str):
         return tuple(map(int, s.split(',')))
-    encoder = Conformer(
+    encoder = Zipformer(
         num_features=params.feature_dim,
         subsampling_factor=params.subsampling_factor,
-        conformer_subsampling_factor=params.conformer_subsampling_factor,
+        zipformer_subsampling_factor=params.zipformer_subsampling_factor,
         d_model=to_int_tuple(params.encoder_dims),
         attention_dim=to_int_tuple(params.attention_dims),
         encoder_unmasked_dim=params.encoder_unmasked_dim,
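
Note: the comma-separated CLI strings defined earlier become one integer per encoder stack before reaching the constructor. A minimal sketch of that conversion, reusing the to_int_tuple helper from this hunk with the script's default values:

def to_int_tuple(s: str):
    # "384,384" -> (384, 384): one entry per encoder stack
    return tuple(map(int, s.split(',')))

assert to_int_tuple("7,7") == (7, 7)              # --num-encoder-layers
assert to_int_tuple("1536,1536") == (1536, 1536)  # --feedforward-dims
assert to_int_tuple("384,384") == (384, 384)      # --encoder-dims
assert to_int_tuple("192,192") == (192, 192)      # --attention-dims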
@@ -618,7 +618,7 @@ def compute_loss(
       params:
         Parameters for training. See :func:`get_params`.
       model:
-        The model for training. It is an instance of Conformer in our case.
+        The model for training. It is an instance of Zipformer in our case.
       batch:
         A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
         for the content in it.
@@ -63,7 +63,7 @@ class Zipformer(EncoderInterface):
         self,
         num_features: int,
         subsampling_factor: int = 4,
-        conformer_subsampling_factor: int = 4,
+        zipformer_subsampling_factor: int = 4,
         d_model: Tuple[int] = (384, 384),
         attention_dim: Tuple[int] = (256, 256),
         encoder_unmasked_dim: int = 256,
@@ -81,7 +81,7 @@ class Zipformer(EncoderInterface):
         self.encoder_unmasked_dim = encoder_unmasked_dim
         assert 0 < d_model[0] <= d_model[1]
         self.d_model = d_model
-        self.conformer_subsampling_factor = conformer_subsampling_factor
+        self.zipformer_subsampling_factor = zipformer_subsampling_factor
 
         assert encoder_unmasked_dim <= d_model[0] and encoder_unmasked_dim <= d_model[1]
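
Note: the asserts in this hunk pin down how the per-stack widths and the unmasked dimension relate. A small worked check using the script's defaults (--encoder-dims 384,384 and an unmasked dimension of 256); the standalone values are purely illustrative:

d_model = (384, 384)        # widths of the 1st and 2nd encoder stacks
encoder_unmasked_dim = 256  # channels that are never randomly masked

# The 2nd stack must be at least as wide as the 1st ...
assert 0 < d_model[0] <= d_model[1]
# ... and the unmasked channels must fit inside both stacks.
assert encoder_unmasked_dim <= d_model[0] and encoder_unmasked_dim <= d_model[1]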
@@ -134,7 +134,7 @@ class Zipformer(EncoderInterface):
             ),
             input_dim=d_model[0],
             output_dim=d_model[1],
-            downsample=conformer_subsampling_factor,
+            downsample=zipformer_subsampling_factor,
         )
 
         self.out_combiner = SimpleCombiner(d_model[0],
@@ -152,7 +152,7 @@ class Zipformer(EncoderInterface):
 
         We generate the random masks at this level because we want the 2 masks to 'agree'
         all the way up the encoder stack. This will mean that the 1st mask will have
-        mask values repeated self.conformer_subsampling_factor times.
+        mask values repeated self.zipformer_subsampling_factor times.
 
         Args:
           x: the embeddings (needed for the shape and dtype and device), of shape
@@ -164,7 +164,7 @@ class Zipformer(EncoderInterface):
         d_model0, d_model1 = self.d_model
         (num_frames0, batch_size, _d_model0) = x.shape
         assert d_model0 == _d_model0
-        ds = self.conformer_subsampling_factor
+        ds = self.zipformer_subsampling_factor
         num_frames1 = ((num_frames0 + ds - 1) // ds)
 
         # on this proportion of the frames, drop out the extra features above
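
Note: num_frames1 is a ceiling division, so the downsampled stack covers every full-rate frame. A sketch of how the two masks can 'agree' across the stacks, assuming simple per-frame boolean masks (the real code in this file also involves feature dimensions and dropout proportions):

import torch

ds = 4             # zipformer_subsampling_factor
num_frames0 = 10   # frames at the full rate (1st stack)

# Ceiling division as in the hunk: ceil(10 / 4) = 3 downsampled frames.
num_frames1 = (num_frames0 + ds - 1) // ds
assert num_frames1 == 3

# One mask value per downsampled frame, each repeated ds times and then
# trimmed, so the 1st-stack mask agrees with the 2nd-stack mask.
mask1 = torch.rand(num_frames1) > 0.5
mask0 = mask1.repeat_interleave(ds)[:num_frames0]
assert mask0.shape[0] == num_frames0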
@@ -380,9 +380,9 @@ class ZipformerEncoder(nn.Module):
 
     Examples::
         >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8)
-        >>> conformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6)
+        >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6)
         >>> src = torch.rand(10, 32, 512)
-        >>> out = conformer_encoder(src)
+        >>> out = zipformer_encoder(src)
     """
     def __init__(
         self,
@@ -555,7 +555,7 @@ class ZipformerEncoder(nn.Module):
 
 class DownsampledZipformerEncoder(nn.Module):
     r"""
-    DownsampledZipformerEncoder is a conformer encoder evaluated at a reduced frame rate,
+    DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate,
     after convolutional downsampling, and then upsampled again at the output
     so that the output has the same shape as the input.
     """
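
Note: the docstring describes a downsample, encode, upsample pattern whose output length matches the input. A shape-only sketch of that contract, with hypothetical names and a stride-based stand-in rather than the class's real convolutional downsampling:

import torch

seq_len, batch, d = 20, 5, 384
x = torch.rand(seq_len, batch, d)

ds = 2
x_ds = x[::ds]                                   # crude time-axis downsampling
y_ds = x_ds                                      # placeholder for the encoder
y = y_ds.repeat_interleave(ds, dim=0)[:seq_len]  # upsample back, trim to input length
assert y.shape == x.shape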
@@ -1298,7 +1298,7 @@ class FeedforwardModule(nn.Module):
 
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Zipformer model.
-    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py
 
     Args:
         channels (int): The number of channels of conv layers.
@@ -1509,7 +1509,7 @@ class AttentionCombine(nn.Module):
     to the identity transform.
 
     The idea is that the list of Tensors will be a list of outputs of multiple
-    conformer layers. This has a similar effect as iterated loss. (See:
+    zipformer layers. This has a similar effect as iterated loss. (See:
     DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
     NETWORKS).
     """
@@ -1634,7 +1634,7 @@ def _test_random_combine():
     assert torch.allclose(y, x[0])  # .. since actually all ones.
 
 
-def _test_conformer_main():
+def _test_zipformer_main():
     feature_dim = 50
     batch_size = 5
     seq_len = 20
@@ -1665,4 +1665,4 @@ if __name__ == "__main__":
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
     _test_random_combine()
-    _test_conformer_main()
+    _test_zipformer_main()