Rename Conformer to Zipformer

This commit is contained in:
Daniel Povey 2022-10-27 22:43:46 +08:00
parent 3f05e47447
commit 5dfa141ca5
2 changed files with 23 additions and 23 deletions

View File

@ -94,35 +94,35 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"--num-encoder-layers", "--num-encoder-layers",
type=str, type=str,
default="7,7", default="7,7",
help="Number of conformer encoder layers, comma separated.", help="Number of zipformer encoder layers, comma separated.",
) )
parser.add_argument( parser.add_argument(
"--feedforward-dims", "--feedforward-dims",
type=str, type=str,
default="1536,1536", default="1536,1536",
help="Feedforward dimension of the conformer encoder layers, comma separated.", help="Feedforward dimension of the zipformer encoder layers, comma separated.",
) )
parser.add_argument( parser.add_argument(
"--nhead", "--nhead",
type=str, type=str,
default="8,8", default="8,8",
help="Number of attention heads in the conformer encoder layers.", help="Number of attention heads in the zipformer encoder layers.",
) )
parser.add_argument( parser.add_argument(
"--encoder-dims", "--encoder-dims",
type=str, type=str,
default="384,384", default="384,384",
help="Embedding dimension in the 2 blocks of conformer encoder layers, comma separated" help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated"
) )
parser.add_argument( parser.add_argument(
"--attention-dims", "--attention-dims",
type=str, type=str,
default="192,192", default="192,192",
help="""Attention dimension in the 2 blocks of conformer encoder layers, comma separated; help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
not the same as embedding dimension.""" not the same as embedding dimension."""
) )
@ -136,7 +136,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
) )
parser.add_argument( parser.add_argument(
"--conformer-subsampling-factor", "--zipformer-subsampling-factor",
type=int, type=int,
default=2, default=2,
help="Subsampling factor for 2nd stack of encoder layers.", help="Subsampling factor for 2nd stack of encoder layers.",
@ -419,7 +419,7 @@ def get_params() -> AttributeDict:
"log_interval": 50, "log_interval": 50,
"reset_interval": 200, "reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800 "valid_interval": 3000, # For the 100h subset, use 800
# parameters for conformer # parameters for zipformer
"feature_dim": 80, "feature_dim": 80,
"subsampling_factor": 4, "subsampling_factor": 4,
"warm_step": 2000, "warm_step": 2000,
@ -431,13 +431,13 @@ def get_params() -> AttributeDict:
def get_encoder_model(params: AttributeDict) -> nn.Module: def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer # TODO: We can add an option to switch between Zipformer and Transformer
def to_int_tuple(s: str): def to_int_tuple(s: str):
return tuple(map(int, s.split(','))) return tuple(map(int, s.split(',')))
encoder = Conformer( encoder = Zipformer(
num_features=params.feature_dim, num_features=params.feature_dim,
subsampling_factor=params.subsampling_factor, subsampling_factor=params.subsampling_factor,
conformer_subsampling_factor=params.conformer_subsampling_factor, zipformer_subsampling_factor=params.zipformer_subsampling_factor,
d_model=to_int_tuple(params.encoder_dims), d_model=to_int_tuple(params.encoder_dims),
attention_dim=to_int_tuple(params.attention_dims), attention_dim=to_int_tuple(params.attention_dims),
encoder_unmasked_dim=params.encoder_unmasked_dim, encoder_unmasked_dim=params.encoder_unmasked_dim,
@ -618,7 +618,7 @@ def compute_loss(
params: params:
Parameters for training. See :func:`get_params`. Parameters for training. See :func:`get_params`.
model: model:
The model for training. It is an instance of Conformer in our case. The model for training. It is an instance of Zipformer in our case.
batch: batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it. for the content in it.

View File

@ -63,7 +63,7 @@ class Zipformer(EncoderInterface):
self, self,
num_features: int, num_features: int,
subsampling_factor: int = 4, subsampling_factor: int = 4,
conformer_subsampling_factor: int = 4, zipformer_subsampling_factor: int = 4,
d_model: Tuple[int] = (384, 384), d_model: Tuple[int] = (384, 384),
attention_dim: Tuple[int] = (256, 256), attention_dim: Tuple[int] = (256, 256),
encoder_unmasked_dim: int = 256, encoder_unmasked_dim: int = 256,
@ -81,7 +81,7 @@ class Zipformer(EncoderInterface):
self.encoder_unmasked_dim = encoder_unmasked_dim self.encoder_unmasked_dim = encoder_unmasked_dim
assert 0 < d_model[0] <= d_model[1] assert 0 < d_model[0] <= d_model[1]
self.d_model = d_model self.d_model = d_model
self.conformer_subsampling_factor = conformer_subsampling_factor self.zipformer_subsampling_factor = zipformer_subsampling_factor
assert encoder_unmasked_dim <= d_model[0] and encoder_unmasked_dim <= d_model[1] assert encoder_unmasked_dim <= d_model[0] and encoder_unmasked_dim <= d_model[1]
@ -134,7 +134,7 @@ class Zipformer(EncoderInterface):
), ),
input_dim=d_model[0], input_dim=d_model[0],
output_dim=d_model[1], output_dim=d_model[1],
downsample=conformer_subsampling_factor, downsample=zipformer_subsampling_factor,
) )
self.out_combiner = SimpleCombiner(d_model[0], self.out_combiner = SimpleCombiner(d_model[0],
@ -152,7 +152,7 @@ class Zipformer(EncoderInterface):
We generate the random masks at this level because we want the 2 masks to 'agree' We generate the random masks at this level because we want the 2 masks to 'agree'
all the way up the encoder stack. This will mean that the 1st mask will have all the way up the encoder stack. This will mean that the 1st mask will have
mask values repeated self.conformer_subsampling_factor times. mask values repeated self.zipformer_subsampling_factor times.
Args: Args:
x: the embeddings (needed for the shape and dtype and device), of shape x: the embeddings (needed for the shape and dtype and device), of shape
@ -164,7 +164,7 @@ class Zipformer(EncoderInterface):
d_model0, d_model1 = self.d_model d_model0, d_model1 = self.d_model
(num_frames0, batch_size, _d_model0) = x.shape (num_frames0, batch_size, _d_model0) = x.shape
assert d_model0 == _d_model0 assert d_model0 == _d_model0
ds = self.conformer_subsampling_factor ds = self.zipformer_subsampling_factor
num_frames1 = ((num_frames0 + ds - 1) // ds) num_frames1 = ((num_frames0 + ds - 1) // ds)
# on this proportion of the frames, drop out the extra features above # on this proportion of the frames, drop out the extra features above
@ -380,9 +380,9 @@ class ZipformerEncoder(nn.Module):
Examples:: Examples::
>>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8)
>>> conformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6) >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512) >>> src = torch.rand(10, 32, 512)
>>> out = conformer_encoder(src) >>> out = zipformer_encoder(src)
""" """
def __init__( def __init__(
self, self,
@ -555,7 +555,7 @@ class ZipformerEncoder(nn.Module):
class DownsampledZipformerEncoder(nn.Module): class DownsampledZipformerEncoder(nn.Module):
r""" r"""
DownsampledZipformerEncoder is a conformer encoder evaluated at a reduced frame rate, DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate,
after convolutional downsampling, and then upsampled again at the output after convolutional downsampling, and then upsampled again at the output
so that the output has the same shape as the input. so that the output has the same shape as the input.
""" """
@ -1298,7 +1298,7 @@ class FeedforwardModule(nn.Module):
class ConvolutionModule(nn.Module): class ConvolutionModule(nn.Module):
"""ConvolutionModule in Zipformer model. """ConvolutionModule in Zipformer model.
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
Args: Args:
channels (int): The number of channels of conv layers. channels (int): The number of channels of conv layers.
@ -1509,7 +1509,7 @@ class AttentionCombine(nn.Module):
to the identity transform. to the identity transform.
The idea is that the list of Tensors will be a list of outputs of multiple The idea is that the list of Tensors will be a list of outputs of multiple
conformer layers. This has a similar effect as iterated loss. (See: zipformer layers. This has a similar effect as iterated loss. (See:
DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
NETWORKS). NETWORKS).
""" """
@ -1634,7 +1634,7 @@ def _test_random_combine():
assert torch.allclose(y, x[0]) # .. since actually all ones. assert torch.allclose(y, x[0]) # .. since actually all ones.
def _test_conformer_main(): def _test_zipformer_main():
feature_dim = 50 feature_dim = 50
batch_size = 5 batch_size = 5
seq_len = 20 seq_len = 20
@ -1665,4 +1665,4 @@ if __name__ == "__main__":
torch.set_num_threads(1) torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
_test_random_combine() _test_random_combine()
_test_conformer_main() _test_zipformer_main()