diff --git a/egs/aishell/ASR/whisper/model.py b/egs/aishell/ASR/whisper/model.py
new file mode 100644
index 000000000..953b80ff4
--- /dev/null
+++ b/egs/aishell/ASR/whisper/model.py
@@ -0,0 +1,433 @@
+import torch
+import torch.nn as nn
+import base64
+import gzip
+import hashlib
+import io
+import os
+import urllib.request
+import warnings
+from dataclasses import dataclass
+from typing import Dict, Iterable, List, Optional, Union
+
+import numpy as np
+from tqdm import tqdm
+
+import torch.nn.functional as F
+from torch import Tensor
+
+from whisper.decoding import decode as decode_function
+from whisper.transcribe import transcribe as transcribe_function
+
+
+@dataclass
+class ModelDimensions:
+ n_mels: int
+ n_audio_ctx: int
+ n_audio_state: int
+ n_audio_head: int
+ n_audio_layer: int
+ n_vocab: int
+ n_text_ctx: int
+ n_text_state: int
+ n_text_head: int
+ n_text_layer: int
+
+
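+# These thin wrappers around LayerNorm / Linear / Conv1d let the model run in
+# fp16: Linear and Conv1d cast their weights/bias to the input dtype, while
+# LayerNorm computes in fp32 and casts the result back to the input dtype.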
+class LayerNorm(nn.LayerNorm):
+ def forward(self, x: Tensor) -> Tensor:
+ return super().forward(x.float()).type(x.dtype)
+
+
+class Linear(nn.Linear):
+ def forward(self, x: Tensor) -> Tensor:
+ return F.linear(
+ x,
+ self.weight.to(x.dtype),
+ None if self.bias is None else self.bias.to(x.dtype),
+ )
+
+
+class Conv1d(nn.Conv1d):
+ def _conv_forward(
+ self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
+ ) -> Tensor:
+ return super()._conv_forward(
+ x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
+ )
+
+
+def sinusoids(length, channels, max_timescale=10000):
+ """Returns sinusoids for positional embedding"""
+ assert channels % 2 == 0
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(self, n_state: int, n_head: int):
+ super().__init__()
+ self.n_head = n_head
+ self.query = Linear(n_state, n_state)
+ self.key = Linear(n_state, n_state, bias=False)
+ self.value = Linear(n_state, n_state)
+ self.out = Linear(n_state, n_state)
+
+ def forward(
+ self,
+ x: Tensor,
+ xa: Optional[Tensor] = None,
+ mask: Optional[Tensor] = None,
+ kv_cache: Optional[dict] = None,
+ ):
+ q = self.query(x)
+
+ if kv_cache is None or xa is None or self.key not in kv_cache:
+ # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
+ # otherwise, perform key/value projections for self- or cross-attention as usual.
+ k = self.key(x if xa is None else xa)
+ v = self.value(x if xa is None else xa)
+ else:
+ # for cross-attention, calculate keys and values once and reuse in subsequent calls.
+ k = kv_cache[self.key]
+ v = kv_cache[self.value]
+
+ wv, qk = self.qkv_attention(q, k, v, mask)
+ return self.out(wv), qk
+
+ def qkv_attention(
+ self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
+ ):
+ n_batch, n_ctx, n_state = q.shape
+ scale = (n_state // self.n_head) ** -0.25
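+        # Note: multiplying both q and k by scale (= d_head ** -0.25) is equivalent
+        # to the usual (q @ k^T) / sqrt(d_head) attention scaling, split across the
+        # two operands so that intermediate values stay smaller.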
+ q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
+ k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
+ v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+
+ qk = q @ k
+ if mask is not None:
+ qk = qk + mask[:n_ctx, :n_ctx]
+ qk = qk.float()
+
+ w = F.softmax(qk, dim=-1).to(q.dtype)
+ return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
+ super().__init__()
+
+ self.attn = MultiHeadAttention(n_state, n_head)
+ self.attn_ln = LayerNorm(n_state)
+
+ self.cross_attn = (
+ MultiHeadAttention(n_state, n_head) if cross_attention else None
+ )
+ self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
+
+ n_mlp = n_state * 4
+ self.mlp = nn.Sequential(
+ Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
+ )
+ self.mlp_ln = LayerNorm(n_state)
+
+ def forward(
+ self,
+ x: Tensor,
+ xa: Optional[Tensor] = None,
+ mask: Optional[Tensor] = None,
+ kv_cache: Optional[dict] = None,
+ ):
+ x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
+ if self.cross_attn:
+ x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
+ x = x + self.mlp(self.mlp_ln(x))
+ return x
+
+
+class AudioEncoder(nn.Module):
+ def __init__(
+ self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+ ):
+ super().__init__()
+ self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
+ self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
+ self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
+
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
+ [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
+ )
+ self.ln_post = LayerNorm(n_state)
+
+ def forward(self, x: Tensor):
+ """
+ x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
+ the mel spectrogram of the audio
+ """
+ x = F.gelu(self.conv1(x))
+ x = F.gelu(self.conv2(x))
+ x = x.permute(0, 2, 1)
+
+        # Unlike the original Whisper implementation, which asserts
+        # x.shape[1:] == self.positional_embedding.shape, we slice the positional
+        # embedding so that inputs shorter than n_audio_ctx frames are accepted
+        # (the training script no longer pads features to 3000 frames).
+        x = (x + self.positional_embedding[: x.shape[1], :]).to(x.dtype)
+
+ for block in self.blocks:
+ x = block(x)
+
+ x = self.ln_post(x)
+ return x
+
+
+class TextDecoder(nn.Module):
+ def __init__(
+ self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+ ):
+ super().__init__()
+
+ self.token_embedding = nn.Embedding(n_vocab, n_state)
+ self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
+
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
+ [
+ ResidualAttentionBlock(n_state, n_head, cross_attention=True)
+ for _ in range(n_layer)
+ ]
+ )
+ self.ln = LayerNorm(n_state)
+
+ mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
+ self.register_buffer("mask", mask, persistent=False)
+
+ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
+ """
+ x : torch.LongTensor, shape = (batch_size, <= n_ctx)
+ the text tokens
+ xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
+ the encoded audio features to be attended on
+ """
+ offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+ x = (
+ self.token_embedding(x)
+ + self.positional_embedding[offset : offset + x.shape[-1]]
+ )
+ x = x.to(xa.dtype)
+
+ for block in self.blocks:
+ x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
+
+ x = self.ln(x)
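+        # weight tying: project the decoder output back onto the vocabulary using
+        # the transpose of the token-embedding matrix (no separate output layer).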
+ logits = (
+ x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+ ).float()
+
+ return logits
+
+
+class Whisper(nn.Module):
+ def __init__(self, dims: ModelDimensions):
+ super().__init__()
+ self.dims = dims
+ self.encoder = AudioEncoder(
+ self.dims.n_mels,
+ self.dims.n_audio_ctx,
+ self.dims.n_audio_state,
+ self.dims.n_audio_head,
+ self.dims.n_audio_layer,
+ )
+ self.decoder = TextDecoder(
+ self.dims.n_vocab,
+ self.dims.n_text_ctx,
+ self.dims.n_text_state,
+ self.dims.n_text_head,
+ self.dims.n_text_layer,
+ )
+ # use the last half layers for alignment by default; see `set_alignment_heads()` below
+ all_heads = torch.zeros(
+ self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
+ )
+ all_heads[self.dims.n_text_layer // 2 :] = True
+ self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+
+ def set_alignment_heads(self, dump: bytes):
+ array = np.frombuffer(
+ gzip.decompress(base64.b85decode(dump)), dtype=bool
+ ).copy()
+ mask = torch.from_numpy(array).reshape(
+ self.dims.n_text_layer, self.dims.n_text_head
+ )
+ self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
+
+ def embed_audio(self, mel: torch.Tensor):
+ return self.encoder(mel)
+
+ def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
+ return self.decoder(tokens, audio_features)
+
+ def forward(
+ self, mel: torch.Tensor, tokens: torch.Tensor
+ ) -> Dict[str, torch.Tensor]:
+ return self.decoder(tokens, self.encoder(mel))
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def is_multilingual(self):
+ return self.dims.n_vocab == 51865
+
+ def install_kv_cache_hooks(self, cache: Optional[dict] = None):
+ """
+ The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
+ tensors calculated for the previous positions. This method returns a dictionary that stores
+ all caches, and the necessary hooks for the key and value projection modules that save the
+ intermediate tensors to be reused during later calculations.
+
+ Returns
+ -------
+ cache : Dict[nn.Module, torch.Tensor]
+ A dictionary object mapping the key/value projection modules to its cache
+ hooks : List[RemovableHandle]
+ List of PyTorch RemovableHandle objects to stop the hooks to be called
+ """
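+        # Illustrative usage (a sketch of an incremental decoding loop):
+        #
+        #   cache, hooks = model.install_kv_cache_hooks()
+        #   logits = model.decoder(tokens, audio_features, kv_cache=cache)
+        #   # ... append the sampled token and call the decoder again, reusing `cache`
+        #   for hook in hooks:
+        #       hook.remove()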
+ cache = {**cache} if cache is not None else {}
+ hooks = []
+
+ def save_to_cache(module, _, output):
+ if module not in cache or output.shape[1] > self.dims.n_text_ctx:
+ # save as-is, for the first token or cross attention
+ cache[module] = output
+ else:
+ cache[module] = torch.cat([cache[module], output], dim=1).detach()
+ return cache[module]
+
+ def install_hooks(layer: nn.Module):
+ if isinstance(layer, MultiHeadAttention):
+ hooks.append(layer.key.register_forward_hook(save_to_cache))
+ hooks.append(layer.value.register_forward_hook(save_to_cache))
+
+ self.decoder.apply(install_hooks)
+ return cache, hooks
+
+ #detect_language = detect_language_function
+ transcribe = transcribe_function
+ decode = decode_function
+
+_MODELS = {
+ "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
+ "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
+ "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
+ "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
+ "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
+ "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
+ "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
+ "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
+ "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
+ "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
+ "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
+}
+
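+
+def available_models() -> List[str]:
+    """Returns the names of the available models (mirrors whisper.available_models())."""
+    return list(_MODELS.keys())
+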
+def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
+ os.makedirs(root, exist_ok=True)
+
+ expected_sha256 = url.split("/")[-2]
+ download_target = os.path.join(root, os.path.basename(url))
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ with open(download_target, "rb") as f:
+ model_bytes = f.read()
+ if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
+ return model_bytes if in_memory else download_target
+ else:
+ warnings.warn(
+ f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
+ )
+
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+ with tqdm(
+ total=int(source.info().get("Content-Length")),
+ ncols=80,
+ unit="iB",
+ unit_scale=True,
+ unit_divisor=1024,
+ ) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ model_bytes = open(download_target, "rb").read()
+ if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
+ raise RuntimeError(
+            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
+ )
+
+ return model_bytes if in_memory else download_target
+
+def load_model(
+ name: str,
+ device: Optional[Union[str, torch.device]] = None,
+ download_root: str = None,
+ in_memory: bool = False,
+) -> Whisper:
+ """
+ Load a Whisper ASR model
+
+ Parameters
+ ----------
+ name : str
+ one of the official model names listed by `whisper.available_models()`, or
+ path to a model checkpoint containing the model dimensions and the model state_dict.
+ device : Union[str, torch.device]
+ the PyTorch device to put the model into
+ download_root: str
+ path to download the model files; by default, it uses "~/.cache/whisper"
+ in_memory: bool
+ whether to preload the model weights into host memory
+
+ Returns
+ -------
+ model : Whisper
+ The Whisper ASR model instance
+ """
+
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ if download_root is None:
+ default = os.path.join(os.path.expanduser("~"), ".cache")
+ download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
+
+ if name in _MODELS:
+ checkpoint_file = _download(_MODELS[name], download_root, in_memory)
+ # alignment_heads = _ALIGNMENT_HEADS[name]
+ alignment_heads = None
+ elif os.path.isfile(name):
+ checkpoint_file = open(name, "rb").read() if in_memory else name
+ alignment_heads = None
+ else:
+ raise RuntimeError(
+ f"Model {name} not found; available models = {available_models()}"
+ )
+
+ with (
+ io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
+ ) as fp:
+ checkpoint = torch.load(fp, map_location=device)
+ del checkpoint_file
+
+ dims = ModelDimensions(**checkpoint["dims"])
+ model = Whisper(dims)
+ model.load_state_dict(checkpoint["model_state_dict"])
+
+ # if alignment_heads is not None:
+ # model.set_alignment_heads(alignment_heads)
+
+ return model.to(device)
+
+
diff --git a/egs/aishell/ASR/whisper/optim.py b/egs/aishell/ASR/whisper/optim.py
new file mode 100644
index 000000000..abfb2092c
--- /dev/null
+++ b/egs/aishell/ASR/whisper/optim.py
@@ -0,0 +1,1173 @@
+# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
+#
+# See ../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import logging
+import random
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from lhotse.utils import fix_random_seed
+from torch import Tensor
+from torch.optim import Optimizer
+
+
+class BatchedOptimizer(Optimizer):
+ """
+ This class adds to class Optimizer the capability to optimize parameters in batches:
+ it will stack the parameters and their grads for you so the optimizer can work
+ on tensors with an extra leading dimension. This is intended for speed with GPUs,
+ as it reduces the number of kernels launched in the optimizer.
+
+    Args:
+      params: iterable of parameters (or dicts defining parameter groups) to optimize.
+      defaults: dict containing default values of optimization options,
+        as in torch.optim.Optimizer.
+ """
+
+ def __init__(self, params, defaults):
+ super(BatchedOptimizer, self).__init__(params, defaults)
+
+ @contextlib.contextmanager
+ def batched_params(self, param_group, group_params_names):
+ """
+        This function returns (technically, yields) a list of
+        tuples (p, state, p_names), where
+ p is a `fake` parameter that is stacked (over axis 0) from real parameters
+ that share the same shape, and its gradient is also stacked;
+ `state` is the state corresponding to this batch of parameters
+ (it will be physically located in the "state" for one of the real
+          parameters, the first one that has any particular shape and dtype).
+
+ This function is decorated as a context manager so that it can
+ write parameters back to their "real" locations.
+
+ The idea is, instead of doing:
+
+ for p in group["params"]:
+ state = self.state[p]
+ ...
+
+ you can do:
+
+          with self.batched_params(group["params"], group_params_names) as batches:
+ for p, state, p_names in batches:
+ ...
+
+
+ Args:
+          param_group: a list of parameters, i.e. the "params" entry of one of
+            self.param_groups.
+          group_params_names: the name of each parameter in param_group,
+            which is a List[str].
+ """
+ batches = defaultdict(
+ list
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
+ batches_names = defaultdict(
+ list
+        )  # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str
+
+ assert len(param_group) == len(group_params_names)
+ for p, named_p in zip(param_group, group_params_names):
+ key = (str(p.dtype), *p.shape)
+ batches[key].append(p)
+ batches_names[key].append(named_p)
+
+ batches_names_keys = list(batches_names.keys())
+ sorted_idx = sorted(
+ range(len(batches_names)), key=lambda i: batches_names_keys[i]
+ )
+ batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
+ batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
+
+ stacked_params_dict = dict()
+
+ # turn batches into a list, in deterministic order.
+ # tuples will contain tuples of (stacked_param, state, stacked_params_names),
+ # one for each batch in `batches`.
+ tuples = []
+
+ for batch, batch_names in zip(batches, batches_names):
+ p = batch[0]
+ # we arbitrarily store the state in the
+ # state corresponding to the 1st parameter in the
+ # group. class Optimizer will take care of saving/loading state.
+ state = self.state[p]
+ p_stacked = torch.stack(batch)
+ grad = torch.stack(
+ [torch.zeros_like(p) if p.grad is None else p.grad for p in batch]
+ )
+ p_stacked.grad = grad
+ stacked_params_dict[key] = p_stacked
+ tuples.append((p_stacked, state, batch_names))
+
+ yield tuples # <-- calling code will do the actual optimization here!
+
+ for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
+ for i, p in enumerate(batch): # batch is list of Parameter
+ p.copy_(stacked_params[i])
+
+
+class ScaledAdam(BatchedOptimizer):
+ """
+ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
+ proportional to the norm of that parameter; and also learn the scale of the parameter,
+ in log space, subject to upper and lower limits (as if we had factored each parameter as
+ param = underlying_param * log_scale.exp())
+
+
+ Args:
+ params: The parameters or param_groups to optimize (like other Optimizer subclasses)
+ Unlike common optimizers, which accept model.parameters() or groups of parameters(),
+ this optimizer could accept model.named_parameters() or groups of named_parameters().
+ See comments of function _get_names_of_parameters for its 4 possible cases.
+ lr: The learning rate. We will typically use a learning rate schedule that starts
+ at 0.03 and decreases over time, i.e. much higher than other common
+ optimizers.
+ clipping_scale: (e.g. 2.0)
+ A scale for gradient-clipping: if specified, the normalized gradients
+ over the whole model will be clipped to have 2-norm equal to
+ `clipping_scale` times the median 2-norm over the most recent period
+ of `clipping_update_period` minibatches. By "normalized gradients",
+ we mean after multiplying by the rms parameter value for this tensor
+ [for non-scalars]; this is appropriate because our update is scaled
+ by this quantity.
+ betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
+           Must satisfy 0 < beta1 <= beta2 < 1.
+    scalar_lr_scale: A scaling factor on the learning rate that we use to update the
+           scale of each parameter tensor and the scalar parameters of the model.
+           If each parameter were decomposed
+           as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
+           would be the scaling factor on the learning rate of p_scale.
+ eps: A general-purpose epsilon to prevent division by zero
+ param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
+ parameter tensor to be >= this value)
+ param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
+ parameter tensor to be <= this value)
+ scalar_max: Maximum absolute value for scalar parameters (applicable if your
+ model has any parameters with numel() == 1).
+ size_update_period: The periodicity, in steps, with which we update the size (scale)
+ of the parameter tensor. This is provided to save a little time
+ in the update.
+    clipping_update_period: if clipping_scale is specified, this is the period (in
+           minibatches) over which gradient norms are collected in order to
+           re-estimate the clipping threshold.
+ """
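+    # Typical usage (a sketch; see also _test_scaled_adam() at the bottom of this file):
+    #   optimizer = ScaledAdam(model.named_parameters(), lr=0.03, clipping_scale=2.0)
+    #   ...
+    #   loss.backward(); optimizer.step(); optimizer.zero_grad()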
+
+ def __init__(
+ self,
+ params,
+ lr=3e-02,
+ clipping_scale=None,
+ betas=(0.9, 0.98),
+ scalar_lr_scale=0.1,
+ eps=1.0e-08,
+ param_min_rms=1.0e-05,
+ param_max_rms=3.0,
+ scalar_max=10.0,
+ size_update_period=4,
+ clipping_update_period=100,
+ ):
+
+ defaults = dict(
+ lr=lr,
+ clipping_scale=clipping_scale,
+ betas=betas,
+ scalar_lr_scale=scalar_lr_scale,
+ eps=eps,
+ param_min_rms=param_min_rms,
+ param_max_rms=param_max_rms,
+ scalar_max=scalar_max,
+ size_update_period=size_update_period,
+ clipping_update_period=clipping_update_period,
+ )
+
+        # If params only contains parameters or groups of parameters,
+        # i.e. when parameter names are not given,
+        # this flag will be set to False in function _get_names_of_parameters.
+ self.show_dominant_parameters = True
+ param_groups, parameters_names = self._get_names_of_parameters(params)
+ super(ScaledAdam, self).__init__(param_groups, defaults)
+ assert len(self.param_groups) == len(parameters_names)
+ self.parameters_names = parameters_names
+
+ def _get_names_of_parameters(
+ self, params_or_named_params
+ ) -> Tuple[List[Dict], List[List[str]]]:
+ """
+ Args:
+ params_or_named_params: according to the way ScaledAdam is initialized in train.py,
+ this argument could be one of following 4 cases,
+ case 1, a generator of parameter, e.g.:
+ optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0)
+
+ case 2, a list of parameter groups with different config, e.g.:
+ model_param_groups = [
+ {'params': model.encoder.parameters(), 'lr': 0.05},
+ {'params': model.decoder.parameters(), 'lr': 0.01},
+ {'params': model.joiner.parameters(), 'lr': 0.03},
+ ]
+ optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0)
+
+ case 3, a generator of named_parameter, e.g.:
+ optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0)
+
+ case 4, a list of named_parameter groups with different config, e.g.:
+ model_named_param_groups = [
+ {'named_params': model.encoder.named_parameters(), 'lr': 0.05},
+ {'named_params': model.decoder.named_parameters(), 'lr': 0.01},
+ {'named_params': model.joiner.named_parameters(), 'lr': 0.03},
+ ]
+ optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0)
+
+ For case 1 and case 2, input params is used to initialize the underlying torch.optimizer.
+ For case 3 and case 4, firstly, names and params are extracted from input named_params,
+ then, these extracted params are used to initialize the underlying torch.optimizer,
+ and these extracted names are mainly used by function
+ `_show_gradient_dominating_parameter`
+
+ Returns:
+ Returns a tuple containing 2 elements:
+ - `param_groups` with type List[Dict], each Dict element is a parameter group.
+ An example of `param_groups` could be:
+ [
+ {'params': `one iterable of Parameter`, 'lr': 0.05},
+ {'params': `another iterable of Parameter`, 'lr': 0.08},
+ {'params': `a third iterable of Parameter`, 'lr': 0.1},
+ ]
+              - `param_groups_names` with type List[List[str]],
+                each `List[str]` is for a group['params'] in param_groups,
+                and each `str` is the name of a parameter.
+                A dummy name "foo" is assigned to each parameter
+                if the input params are given without names, i.e. case 1 or case 2.
+ """
+ # variable naming convention in this function:
+ # p is short for param.
+ # np is short for named_param.
+ # p_or_np is short for param_or_named_param.
+ # cur is short for current.
+ # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}.
+ # groups is a List[group]
+
+ iterable_or_groups = list(params_or_named_params)
+ if len(iterable_or_groups) == 0:
+ raise ValueError("optimizer got an empty parameter list")
+
+ # The first value of returned tuple. A list of dicts containing at
+ # least 'params' as a key.
+ param_groups = []
+
+ # The second value of returned tuple,
+ # a List[List[str]], each sub-List is for a group.
+ param_groups_names = []
+
+ if not isinstance(iterable_or_groups[0], dict):
+ # case 1 or case 3,
+ # the input is an iterable of parameter or named parameter.
+ param_iterable_cur_group = []
+ param_names_cur_group = []
+ for p_or_np in iterable_or_groups:
+ if isinstance(p_or_np, tuple):
+ # case 3
+ name, param = p_or_np
+ else:
+ # case 1
+ assert isinstance(p_or_np, torch.Tensor)
+ param = p_or_np
+ # Assign a dummy name as a placeholder
+ name = "foo"
+ self.show_dominant_parameters = False
+ param_iterable_cur_group.append(param)
+ param_names_cur_group.append(name)
+ param_groups.append({"params": param_iterable_cur_group})
+ param_groups_names.append(param_names_cur_group)
+ else:
+ # case 2 or case 4
+ # the input is groups of parameter or named parameter.
+ for cur_group in iterable_or_groups:
+ assert "named_params" in cur_group
+                name_list = [x[0] for x in cur_group["named_params"]]
+                p_list = [x[1] for x in cur_group["named_params"]]
+ del cur_group["named_params"]
+ cur_group["params"] = p_list
+ param_groups.append(cur_group)
+ param_groups_names.append(name_list)
+
+ return param_groups, param_groups_names
+
+ def __setstate__(self, state):
+ super(ScaledAdam, self).__setstate__(state)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ batch = True
+
+ for group, group_params_names in zip(self.param_groups, self.parameters_names):
+
+ with self.batched_params(group["params"], group_params_names) as batches:
+
+ # batches is list of pairs (stacked_param, state). stacked_param is like
+ # a regular parameter, and will have a .grad, but the 1st dim corresponds to
+ # a stacking dim, it is not a real dim.
+
+ if (
+ len(batches[0][1]) == 0
+ ): # if len(first state) == 0: not yet initialized
+ clipping_scale = 1
+ else:
+ clipping_scale = self._get_clipping_scale(group, batches)
+
+ for p, state, _ in batches:
+ # Perform optimization step.
+ # grad is not going to be None, we handled that when creating the batches.
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError(
+ "ScaledAdam optimizer does not support sparse gradients"
+ )
+ # State initialization
+ if len(state) == 0:
+ self._init_state(group, p, state)
+
+ self._step_one_batch(group, p, state, clipping_scale)
+
+ return loss
+
+ def _init_state(self, group: dict, p: Tensor, state: dict):
+ """
+ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
+ is actually the batch dimension, corresponding to batched-together
+ parameters of a given shape.
+
+
+ Args:
+ group: Dict to look up configuration values.
+ p: The parameter that we are initializing the state for
+ state: Dict from string to whatever state we are initializing
+ """
+ size_update_period = group["size_update_period"]
+
+ state["step"] = 0
+
+ kwargs = {"device": p.device, "dtype": p.dtype}
+
+ # 'delta' implements conventional momentum. There are
+ # several different kinds of update going on, so rather than
+ # compute "exp_avg" like in Adam, we store and decay a
+ # parameter-change "delta", which combines all forms of
+ # update. this is equivalent to how it's done in Adam,
+ # except for the first few steps.
+ state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+ batch_size = p.shape[0]
+ numel = p.numel() // batch_size
+
+ if numel > 1:
+ # "param_rms" just periodically records the scalar root-mean-square value of
+ # the parameter tensor.
+ # it has a shape like (batch_size, 1, 1, 1, 1)
+ param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
+ state["param_rms"] = param_rms
+
+ state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
+ state["scale_grads"] = torch.zeros(
+ size_update_period, *param_rms.shape, **kwargs
+ )
+
+ # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
+ state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+ def _get_clipping_scale(
+ self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
+ ) -> float:
+ """
+ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
+ by this amount before applying the rest of the update.
+
+ Args:
+ group: the parameter group, an item in self.param_groups
+ tuples: a list of tuples of (param, state, param_names)
+ where param is a batched set of parameters,
+ with a .grad (1st dim is batch dim)
+ and state is the state-dict where optimization parameters are kept.
+            param_names is a List[str] where each str is the name of a parameter
+            in the batched set of parameters "param".
+ """
+ assert len(tuples) >= 1
+ clipping_scale = group["clipping_scale"]
+ (first_p, first_state, _) = tuples[0]
+ step = first_state["step"]
+ if clipping_scale is None or step == 0:
+ # no clipping. return early on step == 0 because the other
+ # parameters' state won't have been initialized yet.
+ return 1.0
+ clipping_update_period = group["clipping_update_period"]
+
+ tot_sumsq = torch.tensor(0.0, device=first_p.device)
+ for (p, state, param_names) in tuples:
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError(
+ "ScaledAdam optimizer does not support sparse gradients"
+ )
+ if p.numel() == p.shape[0]: # a batch of scalars
+ tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
+ else:
+ tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
+
+ tot_norm = tot_sumsq.sqrt()
+ if "model_norms" not in first_state:
+ first_state["model_norms"] = torch.zeros(
+ clipping_update_period, device=p.device
+ )
+ first_state["model_norms"][step % clipping_update_period] = tot_norm
+
+ if step % clipping_update_period == 0:
+ # Print some stats.
+ # We don't reach here if step == 0 because we would have returned
+ # above.
+ sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
+ quartiles = []
+ for n in range(0, 5):
+ index = min(
+ clipping_update_period - 1, (clipping_update_period // 4) * n
+ )
+ quartiles.append(sorted_norms[index].item())
+
+ median = quartiles[2]
+ threshold = clipping_scale * median
+ first_state["model_norm_threshold"] = threshold
+ percent_clipped = (
+ first_state["num_clipped"] * 100.0 / clipping_update_period
+ if "num_clipped" in first_state
+ else 0.0
+ )
+ first_state["num_clipped"] = 0
+ quartiles = " ".join(["%.3e" % x for x in quartiles])
+ logging.info(
+ f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
+ f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
+ )
+
+ if step < clipping_update_period:
+ return 1.0 # We have not yet estimated a norm to clip to.
+ else:
+ try:
+ model_norm_threshold = first_state["model_norm_threshold"]
+ except KeyError:
+ logging.info(
+ "Warning: model_norm_threshold not in state: possibly "
+ "you changed config when restarting, adding clipping_scale option?"
+ )
+ return 1.0
+ ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
+ if ans < 1.0:
+ first_state["num_clipped"] += 1
+ if ans < 0.1:
+                logging.warning(
+ f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
+ )
+ if self.show_dominant_parameters:
+ assert p.shape[0] == len(param_names)
+ self._show_gradient_dominating_parameter(tuples, tot_sumsq)
+ return ans
+
+ def _show_gradient_dominating_parameter(
+ self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
+ ):
+ """
+        Show information about the parameter that dominates tot_sumsq.
+
+ Args:
+ tuples: a list of tuples of (param, state, param_names)
+ where param is a batched set of parameters,
+ with a .grad (1st dim is batch dim)
+ and state is the state-dict where optimization parameters are kept.
+            param_names is a List[str] where each str is the name of a parameter
+            in the batched set of parameters "param".
+          tot_sumsq: sumsq of all parameters. Though it could be computed from
+            tuples, we pass it in to save some time.
+ """
+ all_sumsq_orig = {}
+ for (p, state, batch_param_names) in tuples:
+            # p is a stacked batch of parameters.
+ batch_grad = p.grad
+ if p.numel() == p.shape[0]: # a batch of scalars
+ batch_sumsq_orig = batch_grad**2
+ # Dummy values used by following `zip` statement.
+ batch_rms_orig = torch.ones(p.shape[0])
+ else:
+ batch_rms_orig = state["param_rms"]
+ batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
+ dim=list(range(1, batch_grad.ndim))
+ )
+
+ for name, sumsq_orig, rms, grad in zip(
+ batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
+ ):
+
+ proportion_orig = sumsq_orig / tot_sumsq
+ all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
+
+ assert torch.isclose(
+ sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
+ torch.tensor(1.0),
+ )
+ sorted_by_proportion = {
+ k: v
+ for k, v in sorted(
+ all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True
+ )
+ }
+ dominant_param_name = next(iter(sorted_by_proportion))
+ (
+ dominant_proportion,
+ dominant_sumsq,
+ dominant_rms,
+ dominant_grad,
+ ) = sorted_by_proportion[dominant_param_name]
+ logging.info(
+ f"Parameter dominating tot_sumsq {dominant_param_name}"
+ f" with proportion {dominant_proportion:.2f},"
+ f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
+ f"={dominant_sumsq:.3e},"
+ f" grad_sumsq={(dominant_grad**2).sum():.3e},"
+ f" orig_rms_sq={(dominant_rms**2).item():.3e}"
+ )
+
+ def _step_one_batch(
+ self, group: dict, p: Tensor, state: dict, clipping_scale: float
+ ):
+ """
+ Do the step for one parameter, which is actually going to be a batch of
+ `real` parameters, with dim 0 as the batch dim.
+ Args:
+ group: dict to look up configuration values
+ p: parameter to update (actually multiple parameters stacked together
+ as a batch)
+          state: state-dict for p, to look up the optimizer state
+          clipping_scale: scaling factor for the gradients, computed by
+            _get_clipping_scale(); 1.0 means no clipping.
+ """
+ lr = group["lr"]
+ size_update_period = group["size_update_period"]
+ beta1 = group["betas"][0]
+
+ grad = p.grad
+ if clipping_scale != 1.0:
+ grad = grad * clipping_scale
+ step = state["step"]
+ delta = state["delta"]
+
+ delta.mul_(beta1)
+ batch_size = p.shape[0]
+ numel = p.numel() // batch_size
+ if numel > 1:
+ # Update the size/scale of p, and set param_rms
+ scale_grads = state["scale_grads"]
+ scale_grads[step % size_update_period] = (p * grad).sum(
+ dim=list(range(1, p.ndim)), keepdim=True
+ )
+ if step % size_update_period == size_update_period - 1:
+ param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
+ param_rms.copy_(
+ (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
+ )
+ if step > 0:
+ # self._size_update() learns the overall scale on the
+ # parameter, by shrinking or expanding it.
+ self._size_update(group, scale_grads, p, state)
+
+ if numel == 1:
+ # For parameters with 1 element we just use regular Adam.
+ # Updates delta.
+ self._step_scalar(group, p, state)
+ else:
+ self._step(group, p, state)
+
+ state["step"] = step + 1
+
+ def _size_update(
+ self, group: dict, scale_grads: Tensor, p: Tensor, state: dict
+ ) -> None:
+ """
+ Called only where p.numel() > 1, this updates the scale of the parameter.
+ If we imagine: p = underlying_param * scale.exp(), and we are doing
+ gradient descent on underlying param and on scale, this function does the update
+ on `scale`.
+
+ Args:
+ group: dict to look up configuration values
+ scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
+ grads w.r.t. the scales.
+ p: The parameter to update
+ state: The state-dict of p
+ """
+
+ param_rms = state["param_rms"]
+ beta1, beta2 = group["betas"]
+ size_lr = group["lr"] * group["scalar_lr_scale"]
+ param_min_rms = group["param_min_rms"]
+ param_max_rms = group["param_max_rms"]
+ eps = group["eps"]
+ step = state["step"]
+ batch_size = p.shape[0]
+
+ size_update_period = scale_grads.shape[0]
+ # correct beta2 for the size update period: we will have
+ # faster decay at this level.
+ beta2_corr = beta2**size_update_period
+
+ scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
+ scale_exp_avg_sq.mul_(beta2_corr).add_(
+ (scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
+ alpha=1 - beta2_corr,
+ ) # shape is (batch_size, 1, 1, ...)
+
+ # The 1st time we reach here is when size_step == 1.
+ size_step = (step + 1) // size_update_period
+ bias_correction2 = 1 - beta2_corr**size_step
+ # we don't bother with bias_correction1; this will help prevent divergence
+ # at the start of training.
+
+ denom = scale_exp_avg_sq.sqrt() + eps
+
+ scale_step = (
+ -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
+ )
+
+ is_too_small = param_rms < param_min_rms
+
+ # when the param gets too small, just don't shrink it any further.
+ scale_step.masked_fill_(is_too_small, 0.0)
+
+ # and ensure the parameter rms after update never exceeds param_max_rms.
+ # We have to look at the trained model for parameters at or around the
+ # param_max_rms, because sometimes they can indicate a problem with the
+ # topology or settings.
+ scale_step = torch.minimum(scale_step,
+ (param_max_rms - param_rms) / param_rms)
+
+ delta = state["delta"]
+ # the factor of (1-beta1) relates to momentum.
+ delta.add_(p * scale_step, alpha=(1 - beta1))
+
+ def _step(self, group: dict, p: Tensor, state: dict):
+ """
+ This function does the core update of self.step(), in the case where the members of
+ the batch have more than 1 element.
+
+ Args:
+ group: A dict which will be used to look up configuration values
+          p: The parameter to be updated (its .grad supplies the gradient)
+          state: The state-dict corresponding to parameter p
+
+ This function modifies p.
+ """
+ grad = p.grad
+ lr = group["lr"]
+ beta1, beta2 = group["betas"]
+ eps = group["eps"]
+ param_min_rms = group["param_min_rms"]
+ step = state["step"]
+
+ exp_avg_sq = state["exp_avg_sq"]
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
+
+ this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
+ bias_correction2 = 1 - beta2 ** (this_step + 1)
+ if bias_correction2 < 0.99:
+ # note: not in-place.
+ exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
+
+ denom = exp_avg_sq.sqrt()
+ denom += eps
+ grad = grad / denom
+
+ alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
+
+ delta = state["delta"]
+ delta.add_(grad * alpha)
+ p.add_(delta)
+
+ def _step_scalar(self, group: dict, p: Tensor, state: dict):
+ """
+ A simplified form of the core update for scalar tensors, where we cannot get a good
+ estimate of the parameter rms.
+ """
+ beta1, beta2 = group["betas"]
+ scalar_max = group["scalar_max"]
+ eps = group["eps"]
+ lr = group["lr"] * group["scalar_lr_scale"]
+ grad = p.grad
+
+ exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+ # bias_correction2 is like in Adam. Don't bother with bias_correction1;
+ # slower update at the start will help stability anyway.
+ bias_correction2 = 1 - beta2 ** (state["step"] + 1)
+ denom = (exp_avg_sq / bias_correction2).sqrt() + eps
+
+ delta = state["delta"]
+ delta.add_(grad / denom, alpha=-lr * (1 - beta1))
+ p.clamp_(min=-scalar_max, max=scalar_max)
+ p.add_(delta)
+
+
+class LRScheduler(object):
+ """
+ Base-class for learning rate schedulers where the learning-rate depends on both the
+ batch and the epoch.
+ """
+
+ def __init__(self, optimizer: Optimizer, verbose: bool = False):
+ # Attach optimizer
+ if not isinstance(optimizer, Optimizer):
+ raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
+ self.optimizer = optimizer
+ self.verbose = verbose
+
+ for group in optimizer.param_groups:
+ group.setdefault("base_lr", group["lr"])
+
+ self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]
+
+ self.epoch = 0
+ self.batch = 0
+
+ def state_dict(self):
+ """Returns the state of the scheduler as a :class:`dict`.
+
+ It contains an entry for every variable in self.__dict__ which
+ is not the optimizer.
+ """
+ return {
+ "base_lrs": self.base_lrs,
+ "epoch": self.epoch,
+ "batch": self.batch,
+ }
+
+ def load_state_dict(self, state_dict):
+ """Loads the schedulers state.
+
+ Args:
+ state_dict (dict): scheduler state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ self.__dict__.update(state_dict)
+
+ def get_last_lr(self) -> List[float]:
+ """Return last computed learning rate by current scheduler. Will be a list of float."""
+ return self._last_lr
+
+ def get_lr(self):
+ # Compute list of learning rates from self.epoch and self.batch and
+ # self.base_lrs; this must be overloaded by the user.
+ # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
+ raise NotImplementedError
+
+ def step_batch(self, batch: Optional[int] = None) -> None:
+ # Step the batch index, or just set it. If `batch` is specified, it
+ # must be the batch index from the start of training, i.e. summed over
+ # all epochs.
+ # You can call this in any order; if you don't provide 'batch', it should
+ # of course be called once per batch.
+ if batch is not None:
+ self.batch = batch
+ else:
+ self.batch = self.batch + 1
+ self._set_lrs()
+
+ def step_epoch(self, epoch: Optional[int] = None):
+ # Step the epoch index, or just set it. If you provide the 'epoch' arg,
+ # you should call this at the start of the epoch; if you don't provide the 'epoch'
+ # arg, you should call it at the end of the epoch.
+ if epoch is not None:
+ self.epoch = epoch
+ else:
+ self.epoch = self.epoch + 1
+ self._set_lrs()
+
+ def _set_lrs(self):
+ values = self.get_lr()
+ assert len(values) == len(self.optimizer.param_groups)
+
+ for i, data in enumerate(zip(self.optimizer.param_groups, values)):
+ param_group, lr = data
+ param_group["lr"] = lr
+ self.print_lr(self.verbose, i, lr)
+ self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
+
+ def print_lr(self, is_verbose, group, lr):
+ """Display the current learning rate."""
+ if is_verbose:
+ logging.info(
+ f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
+ f" of group {group} to {lr:.4e}."
+ )
+
+
+class Eden(LRScheduler):
+ """
+ Eden scheduler.
+ The basic formula (before warmup) is:
+ lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
+ (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
+    where `warmup` increases linearly from `warmup_start` (default 0.5) to 1 over
+    `warmup_batches` batches and then stays constant at 1.
+
+
+ E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
+
+ Args:
+ optimizer: the optimizer to change the learning rates on
+ lr_batches: the number of batches after which we start significantly
+ decreasing the learning rate, suggest 5000.
+ lr_epochs: the number of epochs after which we start significantly
+ decreasing the learning rate, suggest 6 if you plan to do e.g.
+ 20 to 40 epochs, but may need smaller number if dataset is huge
+ and you will do few epochs.
+ """
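+    # Typical usage (a sketch):
+    #   scheduler = Eden(optimizer, lr_batches=5000, lr_epochs=6)
+    #   call scheduler.step_batch() after every optimizer.step(), and
+    #   scheduler.step_epoch() at the end of each epoch.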
+
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ lr_batches: Union[int, float],
+ lr_epochs: Union[int, float],
+ warmup_batches: Union[int, float] = 500.0,
+ warmup_start: float = 0.5,
+ verbose: bool = False,
+ ):
+ super(Eden, self).__init__(optimizer, verbose)
+ self.lr_batches = lr_batches
+ self.lr_epochs = lr_epochs
+ self.warmup_batches = warmup_batches
+
+ assert 0.0 <= warmup_start <= 1.0, warmup_start
+ self.warmup_start = warmup_start
+
+ def get_lr(self):
+ factor = (
+ (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
+ ) ** -0.25 * (
+ ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25
+ )
+        warmup_factor = (
+            1.0
+            if self.batch >= self.warmup_batches
+            else self.warmup_start
+            + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
+        )
+
+ return [x * factor * warmup_factor for x in self.base_lrs]
+
+
+def _test_eden():
+ m = torch.nn.Linear(100, 100)
+ optim = ScaledAdam(m.parameters(), lr=0.03)
+
+ scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)
+
+ for epoch in range(10):
+ scheduler.step_epoch(epoch) # sets epoch to `epoch`
+
+ for step in range(20):
+ x = torch.randn(200, 100).detach()
+ x.requires_grad = True
+ y = m(x)
+ dy = torch.randn(200, 100).detach()
+ f = (y * dy).sum()
+ f.backward()
+
+ optim.step()
+ scheduler.step_batch()
+ optim.zero_grad()
+
+ logging.info(f"last lr = {scheduler.get_last_lr()}")
+ logging.info(f"state dict = {scheduler.state_dict()}")
+
+
+# This is included mostly as a baseline for ScaledAdam.
+class Eve(Optimizer):
+ """
+ Implements Eve algorithm. This is a modified version of AdamW with a special
+ way of setting the weight-decay / shrinkage-factor, which is designed to make the
+ rms of the parameters approach a particular target_rms (default: 0.1). This is
+ for use with networks with 'scaled' versions of modules (see scaling.py), which
+ will be close to invariant to the absolute scale on the parameter matrix.
+
+ The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
+ The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
+ Eve is unpublished so far.
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay coefficient (default: 1e-3;
+            this value means that the weight would decay significantly after
+            about 1k minibatches. It is not multiplied by the learning rate, but
+            is applied only when the RMS value of the parameter is > target_rms).
+        target_rms (float, optional): target root-mean-square value of
+            parameters; if they fall below this, we stop applying weight decay.
+
+
+ .. _Adam: A Method for Stochastic Optimization:
+ https://arxiv.org/abs/1412.6980
+ .. _Decoupled Weight Decay Regularization:
+ https://arxiv.org/abs/1711.05101
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
+ """
+
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.98),
+ eps=1e-8,
+ weight_decay=1e-3,
+ target_rms=0.1,
+ ):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ if not 0 <= weight_decay <= 0.1:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ if not 0 < target_rms <= 10.0:
+ raise ValueError("Invalid target_rms value: {}".format(target_rms))
+ defaults = dict(
+ lr=lr,
+ betas=betas,
+ eps=eps,
+ weight_decay=weight_decay,
+ target_rms=target_rms,
+ )
+ super(Eve, self).__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(Eve, self).__setstate__(state)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.grad is None:
+ continue
+
+ # Perform optimization step
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError("AdamW does not support sparse gradients")
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state["step"] = 0
+ # Exponential moving average of gradient values
+ state["exp_avg"] = torch.zeros_like(
+ p, memory_format=torch.preserve_format
+ )
+ # Exponential moving average of squared gradient values
+ state["exp_avg_sq"] = torch.zeros_like(
+ p, memory_format=torch.preserve_format
+ )
+
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+
+ beta1, beta2 = group["betas"]
+
+ state["step"] += 1
+ bias_correction1 = 1 - beta1 ** state["step"]
+ bias_correction2 = 1 - beta2 ** state["step"]
+
+ # Decay the first and second moment running average coefficient
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+ denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_(
+ group["eps"]
+ )
+
+ step_size = group["lr"] / bias_correction1
+ target_rms = group["target_rms"]
+ weight_decay = group["weight_decay"]
+
+ if p.numel() > 1:
+ # avoid applying this weight-decay on "scaling factors"
+ # (which are scalar).
+ is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5))
+ p.mul_(1 - (weight_decay * is_above_target_rms))
+
+ p.addcdiv_(exp_avg, denom, value=-step_size)
+
+ if random.random() < 0.0005:
+ step = (exp_avg / denom) * step_size
+ logging.info(
+ f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}"
+ )
+
+ return loss
+
+
+def _test_scaled_adam(hidden_dim: int):
+ import timeit
+
+ from scaling import ScaledLinear
+
+ E = 100
+ B = 4
+ T = 2
+    logging.info("in _test_scaled_adam")
+ # device = torch.device('cuda')
+ device = torch.device("cpu")
+ dtype = torch.float32
+
+ fix_random_seed(42)
+    # these input_magnitudes and output_magnitudes are to test that
+    # the optimizer is working as we expect and is able to adjust scales of
+    # different dims differently.
+ input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
+ output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
+
+ for iter in [1, 0]:
+ fix_random_seed(42)
+ Linear = torch.nn.Linear if iter == 0 else ScaledLinear
+
+ m = torch.nn.Sequential(
+ Linear(E, hidden_dim),
+ torch.nn.PReLU(),
+ Linear(hidden_dim, hidden_dim),
+ torch.nn.PReLU(),
+ Linear(hidden_dim, E),
+ ).to(device)
+
+ train_pairs = [
+ (
+ 100.0
+ * torch.randn(B, T, E, device=device, dtype=dtype)
+ * input_magnitudes,
+ torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes,
+ )
+ for _ in range(20)
+ ]
+
+ if iter == 0:
+ optim = Eve(m.parameters(), lr=0.003)
+ elif iter == 1:
+ optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
+ scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
+
+ start = timeit.default_timer()
+ avg_loss = 0.0
+ for epoch in range(180):
+ scheduler.step_epoch()
+ # if epoch == 100 and iter in [2,3]:
+ # optim.reset_speedup() # check it doesn't crash.
+
+ # if epoch == 130:
+ # opts = diagnostics.TensorDiagnosticOptions(
+ # 2 ** 22
+ # ) # allow 4 megabytes per sub-module
+ # diagnostic = diagnostics.attach_diagnostics(m, opts)
+
+ for n, (x, y) in enumerate(train_pairs):
+ y_out = m(x)
+ loss = ((y_out - y) ** 2).mean() * 100.0
+ if epoch == 0 and n == 0:
+ avg_loss = loss.item()
+ else:
+ avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
+ if n == 0 and epoch % 5 == 0:
+ # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
+ # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
+ # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
+ # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
+ # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
+ # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
+ # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
+ # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
+ lr = scheduler.get_last_lr()[0]
+ logging.info(
+ f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}"
+ ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
+ loss.log().backward()
+ optim.step()
+ optim.zero_grad()
+ scheduler.step_batch()
+
+ # diagnostic.print_diagnostics()
+
+ stop = timeit.default_timer()
+ logging.info(f"Iter={iter}, Time taken: {stop - start}")
+
+ logging.info(f"last lr = {scheduler.get_last_lr()}")
+ # logging.info("state dict = ", scheduler.state_dict())
+ # logging.info("optim state_dict = ", optim.state_dict())
+ logging.info(f"input_magnitudes = {input_magnitudes}")
+ logging.info(f"output_magnitudes = {output_magnitudes}")
+
+
+if __name__ == "__main__":
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ logging.getLogger().setLevel(logging.INFO)
+ import subprocess
+
+ s = subprocess.check_output(
+ "git status -uno .; git log -1; git diff HEAD .", shell=True
+ )
+ logging.info(s)
+ import sys
+
+ if len(sys.argv) > 1:
+ hidden_dim = int(sys.argv[1])
+ else:
+ hidden_dim = 200
+
+ _test_scaled_adam(hidden_dim)
+ _test_eden()
diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py
index b398ddb81..95af2f056 100644
--- a/egs/aishell/ASR/whisper/train.py
+++ b/egs/aishell/ASR/whisper/train.py
@@ -88,7 +88,7 @@ from icefall.utils import (
)
import whisper
-
+from model import load_model
from label_smoothing import LabelSmoothingLoss
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -631,7 +631,7 @@ def compute_loss(
feature = feature.to(device)
feature = feature.transpose(1, 2) # (N, C, T)
# pad feature from B,80,T to B,80,3000
- feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1]))
+ # feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1]))
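+    # (padding to 3000 frames is no longer needed: the encoder in ./model.py
+    #  slices its positional embedding to the actual number of input frames)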
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
@@ -923,7 +923,8 @@ def run(rank, world_size, args):
logging.info("About to create model")
- model = whisper.load_model("medium")
+ #model = whisper.load_model("medium")
+ model = load_model("medium")
del model.alignment_heads
params.tokenizer = whisper.tokenizer.get_tokenizer(
model.is_multilingual, language="zh", task="transcribe"