diff --git a/egs/librispeech/ASR/transducer_emformer/noam.py b/egs/librispeech/ASR/transducer_emformer/noam.py
new file mode 100644
index 000000000..e46bf35fb
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_emformer/noam.py
@@ -0,0 +1,104 @@
+# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+class Noam(object):
+    """
+    Implements the Noam optimizer.
+
+    Proposed in
+    "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
+
+    Modified from
+    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py  # noqa
+
+    Args:
+      params:
+        iterable of parameters to optimize or dicts defining parameter groups
+      model_size:
+        attention dimension of the transformer model
+      factor:
+        learning rate factor
+      warm_step:
+        warmup steps
+    """
+
+    def __init__(
+        self,
+        params,
+        model_size: int = 256,
+        factor: float = 10.0,
+        warm_step: int = 25000,
+        weight_decay=0,
+    ) -> None:
+        """Construct a Noam object."""
+        self.optimizer = torch.optim.Adam(
+            params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
+        )
+        self._step = 0
+        self.warmup = warm_step
+        self.factor = factor
+        self.model_size = model_size
+        self._rate = 0
+
+    @property
+    def param_groups(self):
+        """Return param_groups."""
+        return self.optimizer.param_groups
+
+    def step(self):
+        """Update parameters and rate."""
+        self._step += 1
+        rate = self.rate()
+        for p in self.optimizer.param_groups:
+            p["lr"] = rate
+        self._rate = rate
+        self.optimizer.step()
+
+    def rate(self, step=None):
+        """Compute the learning rate for the given (or current) step."""
+        if step is None:
+            step = self._step
+        return (
+            self.factor
+            * self.model_size ** (-0.5)
+            * min(step ** (-0.5), step * self.warmup ** (-1.5))
+        )
+
+    def zero_grad(self):
+        """Reset gradient."""
+        self.optimizer.zero_grad()
+
+    def state_dict(self):
+        """Return state_dict."""
+        return {
+            "_step": self._step,
+            "warmup": self.warmup,
+            "factor": self.factor,
+            "model_size": self.model_size,
+            "_rate": self._rate,
+            "optimizer": self.optimizer.state_dict(),
+        }
+
+    def load_state_dict(self, state_dict):
+        """Load state_dict."""
+        for key, value in state_dict.items():
+            if key == "optimizer":
+                self.optimizer.load_state_dict(state_dict["optimizer"])
+            else:
+                setattr(self, key, value)
diff --git a/egs/librispeech/ASR/transducer_emformer/train.py b/egs/librispeech/ASR/transducer_emformer/train.py
index 1f52370fd..175387e03 100755
--- a/egs/librispeech/ASR/transducer_emformer/train.py
+++ b/egs/librispeech/ASR/transducer_emformer/train.py
@@ -21,11 +21,11 @@ Usage:
 
 export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
-./pruned_transducer_stateless/train.py \
+./transducer_emformer/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 0 \
-  --exp-dir pruned_transducer_stateless/exp \
+  --exp-dir transducer_emformer/exp \
   --full-libri 1 \
   --max-duration 300
 """
@@ -33,6 +33,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 import argparse
 import logging
+import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple
@@ -43,18 +44,18 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from conformer import Conformer
 from decoder import Decoder
+from emformer import Emformer
 from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
+from noam import Noam
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
-from transformer import Noam
 
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
@@ -111,7 +112,7 @@ def get_parser():
         default=0,
         help="""Resume training from this epoch.
         If it is positive, it will load checkpoint from
-        transducer_stateless/exp/epoch-{start_epoch-1}.pt
+        transducer_emformer/exp/epoch-{start_epoch-1}.pt
         """,
     )
@@ -127,7 +128,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless/exp",
+        default="transducer_emformer/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -279,18 +280,21 @@ def get_params() -> AttributeDict:
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
             "log_diagnostics": False,
-            # parameters for conformer
+            # parameters for Emformer
             "feature_dim": 80,
             "subsampling_factor": 4,
             "attention_dim": 512,
             "nhead": 8,
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
+            "left_context_length": 120,  # 120 frames
+            "segment_length": 16,
+            "right_context_length": 4,
             # parameters for decoder
             "embedding_dim": 512,
             # parameters for Noam
-            "warm_step": 80000,  # For the 100h subset, use 30000
+            "warm_step": 80000,  # For the 100h subset, use 20000
             "env_info": get_env_info(),
         }
     )
@@ -299,8 +303,7 @@
 
 
 def get_encoder_model(params: AttributeDict) -> nn.Module:
-    # TODO: We can add an option to switch between Conformer and Transformer
-    encoder = Conformer(
+    encoder = Emformer(
         num_features=params.feature_dim,
         output_dim=params.vocab_size,
         subsampling_factor=params.subsampling_factor,
@@ -309,6 +312,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
+        left_context_length=params.left_context_length,
+        segment_length=params.segment_length,
+        right_context_length=params.right_context_length,
     )
     return encoder
 
@@ -496,7 +502,11 @@ def compute_loss(
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
-    info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -725,7 +735,7 @@ def run(rank, world_size, args):
     params.update(vars(args))
     if params.full_libri is False:
         params.valid_interval = 800
-        params.warm_step = 30000
+        params.warm_step = 20000
 
     fix_random_seed(params.seed)
     if world_size > 1:
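
Note (not part of the patch): the schedule implemented by Noam.rate() is

    lr(step) = factor * model_size**(-0.5) * min(step**(-0.5), step * warm_step**(-1.5))

which grows linearly while step < warm_step, decays as step**(-0.5) afterwards, and peaks exactly at step == warm_step. Below is a minimal sketch of how the wrapper is meant to be used. It assumes it runs from egs/librispeech/ASR/transducer_emformer/ so that noam.py is importable; ToyModel and the printed step values are hypothetical illustrations, not code from this patch.

    import torch

    from noam import Noam  # the class added by this patch


    class ToyModel(torch.nn.Module):
        """A tiny model standing in for the real Emformer transducer."""

        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(80, 512)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x)


    model = ToyModel()
    optimizer = Noam(
        model.parameters(),
        model_size=512,  # matches "attention_dim" in get_params() above
        factor=10.0,
        warm_step=80000,  # 20000 for the 100h subset, per this patch
    )

    # Inspect the learning rate at a few steps; it peaks at step == warm_step.
    for step in [1, 1000, 80000, 200000]:
        print(step, optimizer.rate(step))

    # One update looks like a standard PyTorch loop; Noam.step() first sets
    # the learning rate of the wrapped Adam, then steps it.
    loss = model(torch.randn(4, 80)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

This also shows why warm_step drops from 30000 to 20000 for the 100h subset in this patch: with fewer batches per epoch, a shorter warmup reaches the peak learning rate after a comparable amount of data.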