Rename Conformer to Zipformer

Daniel Povey 2022-10-27 22:43:46 +08:00
parent 3f05e47447
commit 5dfa141ca5
2 changed files with 23 additions and 23 deletions

View File

@@ -94,35 +94,35 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "--num-encoder-layers",
         type=str,
         default="7,7",
-        help="Number of conformer encoder layers, comma separated.",
+        help="Number of zipformer encoder layers, comma separated.",
     )
 
     parser.add_argument(
         "--feedforward-dims",
         type=str,
         default="1536,1536",
-        help="Feedforward dimension of the conformer encoder layers, comma separated.",
+        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
     )
 
     parser.add_argument(
         "--nhead",
         type=str,
         default="8,8",
-        help="Number of attention heads in the conformer encoder layers.",
+        help="Number of attention heads in the zipformer encoder layers.",
     )
 
     parser.add_argument(
         "--encoder-dims",
         type=str,
         default="384,384",
-        help="Embedding dimension in the 2 blocks of conformer encoder layers, comma separated"
+        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated"
     )
 
     parser.add_argument(
         "--attention-dims",
         type=str,
         default="192,192",
-        help="""Attention dimension in the 2 blocks of conformer encoder layers, comma separated;
+        help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
         not the same as embedding dimension."""
     )
@@ -136,7 +136,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     )
 
     parser.add_argument(
-        "--conformer-subsampling-factor",
+        "--zipformer-subsampling-factor",
         type=int,
         default=2,
         help="Subsampling factor for 2nd stack of encoder layers.",
@@ -419,7 +419,7 @@ def get_params() -> AttributeDict:
             "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
-            # parameters for conformer
+            # parameters for zipformer
             "feature_dim": 80,
             "subsampling_factor": 4,
             "warm_step": 2000,
@@ -431,13 +431,13 @@
 def get_encoder_model(params: AttributeDict) -> nn.Module:
-    # TODO: We can add an option to switch between Conformer and Transformer
+    # TODO: We can add an option to switch between Zipformer and Transformer
 
     def to_int_tuple(s: str):
         return tuple(map(int, s.split(',')))
 
-    encoder = Conformer(
+    encoder = Zipformer(
         num_features=params.feature_dim,
         subsampling_factor=params.subsampling_factor,
-        conformer_subsampling_factor=params.conformer_subsampling_factor,
+        zipformer_subsampling_factor=params.zipformer_subsampling_factor,
         d_model=to_int_tuple(params.encoder_dims),
         attention_dim=to_int_tuple(params.attention_dims),
         encoder_unmasked_dim=params.encoder_unmasked_dim,
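The to_int_tuple helper above is what turns the comma-separated CLI strings into per-stack integer tuples; a standalone check of its behavior:

# Same one-liner as in the diff: comma-separated string -> tuple of ints.
def to_int_tuple(s: str):
    return tuple(map(int, s.split(',')))

assert to_int_tuple("384,384") == (384, 384)   # --encoder-dims
assert to_int_tuple("192,192") == (192, 192)   # --attention-dims
assert to_int_tuple("7,7") == (7, 7)           # --num-encoder-layers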
@@ -618,7 +618,7 @@ def compute_loss(
       params:
         Parameters for training. See :func:`get_params`.
       model:
-        The model for training. It is an instance of Conformer in our case.
+        The model for training. It is an instance of Zipformer in our case.
       batch:
         A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
         for the content in it.

View File

@@ -63,7 +63,7 @@ class Zipformer(EncoderInterface):
         self,
         num_features: int,
         subsampling_factor: int = 4,
-        conformer_subsampling_factor: int = 4,
+        zipformer_subsampling_factor: int = 4,
         d_model: Tuple[int] = (384, 384),
         attention_dim: Tuple[int] = (256, 256),
         encoder_unmasked_dim: int = 256,
@@ -81,7 +81,7 @@ class Zipformer(EncoderInterface):
         self.encoder_unmasked_dim = encoder_unmasked_dim
         assert 0 < d_model[0] <= d_model[1]
         self.d_model = d_model
-        self.conformer_subsampling_factor = conformer_subsampling_factor
+        self.zipformer_subsampling_factor = zipformer_subsampling_factor
 
         assert encoder_unmasked_dim <= d_model[0] and encoder_unmasked_dim <= d_model[1]
@@ -134,7 +134,7 @@ class Zipformer(EncoderInterface):
             ),
             input_dim=d_model[0],
             output_dim=d_model[1],
-            downsample=conformer_subsampling_factor,
+            downsample=zipformer_subsampling_factor,
         )
 
         self.out_combiner = SimpleCombiner(d_model[0],
@@ -152,7 +152,7 @@ class Zipformer(EncoderInterface):
 
         We generate the random masks at this level because we want the 2 masks to 'agree'
         all the way up the encoder stack. This will mean that the 1st mask will have
-        mask values repeated self.conformer_subsampling_factor times.
+        mask values repeated self.zipformer_subsampling_factor times.
 
         Args:
           x: the embeddings (needed for the shape and dtype and device), of shape
@@ -164,7 +164,7 @@
         d_model0, d_model1 = self.d_model
         (num_frames0, batch_size, _d_model0) = x.shape
         assert d_model0 == _d_model0
-        ds = self.conformer_subsampling_factor
+        ds = self.zipformer_subsampling_factor
         num_frames1 = ((num_frames0 + ds - 1) // ds)
 
         # on this proportion of the frames, drop out the extra features above
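The ceil division above computes the frame count after downsampling by ds; to make the masks at the two frame rates 'agree', one can draw the mask once at the lower rate and repeat each value ds times. A sketch of that idea (not the repository's exact code; keep_prob and the trimming are illustrative):

import torch

def agreeing_masks(num_frames0: int, ds: int, keep_prob: float = 0.9):
    # frames after downsampling by ds, same ceil division as in the diff
    num_frames1 = (num_frames0 + ds - 1) // ds
    # draw the mask once at the reduced frame rate ...
    mask1 = (torch.rand(num_frames1) < keep_prob).to(torch.float32)
    # ... then repeat each value ds times so the full-rate mask agrees,
    # trimming the padding introduced by the ceil division
    mask0 = mask1.repeat_interleave(ds)[:num_frames0]
    return mask0, mask1

mask0, mask1 = agreeing_masks(num_frames0=23, ds=2)
assert mask0.shape == (23,) and mask1.shape == (12,)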
@@ -380,9 +380,9 @@ class ZipformerEncoder(nn.Module):
     Examples::
         >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8)
-        >>> conformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6)
+        >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6)
         >>> src = torch.rand(10, 32, 512)
-        >>> out = conformer_encoder(src)
+        >>> out = zipformer_encoder(src)
     """
 
     def __init__(
         self,
@@ -555,7 +555,7 @@ class ZipformerEncoder(nn.Module):
 
 class DownsampledZipformerEncoder(nn.Module):
     r"""
-    DownsampledZipformerEncoder is a conformer encoder evaluated at a reduced frame rate,
+    DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate,
     after convolutional downsampling, and then upsampled again at the output
     so that the output has the same shape as the input.
     """
@@ -1298,7 +1298,7 @@ class FeedforwardModule(nn.Module):
 
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Zipformer model.
-    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py
 
     Args:
         channels (int): The number of channels of conv layers.
@@ -1509,7 +1509,7 @@ class AttentionCombine(nn.Module):
     to the identity transform.
 
     The idea is that the list of Tensors will be a list of outputs of multiple
-    conformer layers. This has a similar effect as iterated loss. (See:
+    zipformer layers. This has a similar effect as iterated loss. (See:
     DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
     NETWORKS).
     """
@@ -1634,7 +1634,7 @@ def _test_random_combine():
     assert torch.allclose(y, x[0])  # .. since actually all ones.
 
 
-def _test_conformer_main():
+def _test_zipformer_main():
    feature_dim = 50
    batch_size = 5
    seq_len = 20
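The test body is truncated in this view. In the same spirit, a smoke test would push random features through the encoder and check the output shapes; treat everything below as an assumption inferred from the signatures in this diff (import path, constructor defaults, and the (x, x_lens) forward convention), not as the repository's actual test:

import torch

from zipformer import Zipformer  # import path assumed

def _test_zipformer_main():
    feature_dim = 50
    batch_size = 5
    seq_len = 20
    # constructor arguments assumed from the __init__ signature shown above
    model = Zipformer(num_features=feature_dim)
    x = torch.rand(batch_size, seq_len, feature_dim)
    x_lens = torch.full((batch_size,), seq_len, dtype=torch.int64)
    y, y_lens = model(x, x_lens)  # output frame rate is subsampled
    print(y.shape, y_lens)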
@@ -1665,4 +1665,4 @@ if __name__ == "__main__":
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
     _test_random_combine()
-    _test_conformer_main()
+    _test_zipformer_main()