From 5bf3a9cfe02af0f092ea3d48d65ba20e1df2a46f Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Tue, 26 Sep 2023 17:09:04 +0800 Subject: [PATCH] using audio with any length --- egs/aishell/ASR/whisper/model.py | 433 +++++++++++ egs/aishell/ASR/whisper/optim.py | 1173 ++++++++++++++++++++++++++++++ egs/aishell/ASR/whisper/train.py | 7 +- 3 files changed, 1610 insertions(+), 3 deletions(-) create mode 100644 egs/aishell/ASR/whisper/model.py create mode 100644 egs/aishell/ASR/whisper/optim.py diff --git a/egs/aishell/ASR/whisper/model.py b/egs/aishell/ASR/whisper/model.py new file mode 100644 index 000000000..953b80ff4 --- /dev/null +++ b/egs/aishell/ASR/whisper/model.py @@ -0,0 +1,433 @@ +import torch +import torch.nn as nn +import base64 +import gzip +from dataclasses import dataclass +from typing import Dict, Iterable, Optional, Union +import os +import urllib +import hashlib +import numpy as np + +import torch.nn.functional as F +from torch import Tensor + +from whisper.decoding import decode as decode_function +from whisper.transcribe import transcribe as transcribe_function + + +@dataclass +class ModelDimensions: + n_mels: int + n_audio_ctx: int + n_audio_state: int + n_audio_head: int + n_audio_layer: int + n_vocab: int + n_text_ctx: int + n_text_state: int + n_text_head: int + n_text_layer: int + + +class LayerNorm(nn.LayerNorm): + def forward(self, x: Tensor) -> Tensor: + return super().forward(x.float()).type(x.dtype) + + +class Linear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + return F.linear( + x, + self.weight.to(x.dtype), + None if self.bias is None else self.bias.to(x.dtype), + ) + + +class Conv1d(nn.Conv1d): + def _conv_forward( + self, x: Tensor, weight: Tensor, bias: Optional[Tensor] + ) -> Tensor: + return super()._conv_forward( + x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) + ) + + +def sinusoids(length, channels, max_timescale=10000): + """Returns sinusoids for positional embedding""" + assert channels % 2 == 0 + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + + +class MultiHeadAttention(nn.Module): + def __init__(self, n_state: int, n_head: int): + super().__init__() + self.n_head = n_head + self.query = Linear(n_state, n_state) + self.key = Linear(n_state, n_state, bias=False) + self.value = Linear(n_state, n_state) + self.out = Linear(n_state, n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + ): + q = self.query(x) + + if kv_cache is None or xa is None or self.key not in kv_cache: + # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; + # otherwise, perform key/value projections for self- or cross-attention as usual. + k = self.key(x if xa is None else xa) + v = self.value(x if xa is None else xa) + else: + # for cross-attention, calculate keys and values once and reuse in subsequent calls. + k = kv_cache[self.key] + v = kv_cache[self.value] + + wv, qk = self.qkv_attention(q, k, v, mask) + return self.out(wv), qk + + def qkv_attention( + self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None + ): + n_batch, n_ctx, n_state = q.shape + scale = (n_state // self.n_head) ** -0.25 + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + qk = q @ k + if mask is not None: + qk = qk + mask[:n_ctx, :n_ctx] + qk = qk.float() + + w = F.softmax(qk, dim=-1).to(q.dtype) + return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): + super().__init__() + + self.attn = MultiHeadAttention(n_state, n_head) + self.attn_ln = LayerNorm(n_state) + + self.cross_attn = ( + MultiHeadAttention(n_state, n_head) if cross_attention else None + ) + self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None + + n_mlp = n_state * 4 + self.mlp = nn.Sequential( + Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state) + ) + self.mlp_ln = LayerNorm(n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + ): + x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] + if self.cross_attn: + x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] + x = x + self.mlp(self.mlp_ln(x)) + return x + + +class AudioEncoder(nn.Module): + def __init__( + self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int + ): + super().__init__() + self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) + self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) + self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) + + self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( + [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] + ) + self.ln_post = LayerNorm(n_state) + + def forward(self, x: Tensor): + """ + x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) + the mel spectrogram of the audio + """ + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) + + # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" + + x = (x + self.positional_embedding[:x.shape[1],:]).to(x.dtype) + #x = (x + self.positional_embedding).to(x.dtype) + + for block in self.blocks: + x = block(x) + + x = self.ln_post(x) + return x + + +class TextDecoder(nn.Module): + def __init__( + self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int + ): + super().__init__() + + self.token_embedding = nn.Embedding(n_vocab, n_state) + self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) + + self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( + [ + ResidualAttentionBlock(n_state, n_head, cross_attention=True) + for _ in range(n_layer) + ] + ) + self.ln = LayerNorm(n_state) + + mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) + self.register_buffer("mask", mask, persistent=False) + + def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): + """ + x : torch.LongTensor, shape = (batch_size, <= n_ctx) + the text tokens + xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) + the encoded audio features to be attended on + """ + offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 + x = ( + self.token_embedding(x) + + self.positional_embedding[offset : offset + x.shape[-1]] + ) + x = x.to(xa.dtype) + + for block in self.blocks: + x = block(x, xa, mask=self.mask, kv_cache=kv_cache) + + x = self.ln(x) + logits = ( + x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) + ).float() + + return logits + + +class Whisper(nn.Module): + def __init__(self, dims: ModelDimensions): + super().__init__() + self.dims = dims + self.encoder = AudioEncoder( + self.dims.n_mels, + self.dims.n_audio_ctx, + self.dims.n_audio_state, + self.dims.n_audio_head, + self.dims.n_audio_layer, + ) + self.decoder = TextDecoder( + self.dims.n_vocab, + self.dims.n_text_ctx, + self.dims.n_text_state, + self.dims.n_text_head, + self.dims.n_text_layer, + ) + # use the last half layers for alignment by default; see `set_alignment_heads()` below + all_heads = torch.zeros( + self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool + ) + all_heads[self.dims.n_text_layer // 2 :] = True + self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False) + + def set_alignment_heads(self, dump: bytes): + array = np.frombuffer( + gzip.decompress(base64.b85decode(dump)), dtype=bool + ).copy() + mask = torch.from_numpy(array).reshape( + self.dims.n_text_layer, self.dims.n_text_head + ) + self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False) + + def embed_audio(self, mel: torch.Tensor): + return self.encoder(mel) + + def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): + return self.decoder(tokens, audio_features) + + def forward( + self, mel: torch.Tensor, tokens: torch.Tensor + ) -> Dict[str, torch.Tensor]: + return self.decoder(tokens, self.encoder(mel)) + + @property + def device(self): + return next(self.parameters()).device + + @property + def is_multilingual(self): + return self.dims.n_vocab == 51865 + + def install_kv_cache_hooks(self, cache: Optional[dict] = None): + """ + The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value + tensors calculated for the previous positions. This method returns a dictionary that stores + all caches, and the necessary hooks for the key and value projection modules that save the + intermediate tensors to be reused during later calculations. + + Returns + ------- + cache : Dict[nn.Module, torch.Tensor] + A dictionary object mapping the key/value projection modules to its cache + hooks : List[RemovableHandle] + List of PyTorch RemovableHandle objects to stop the hooks to be called + """ + cache = {**cache} if cache is not None else {} + hooks = [] + + def save_to_cache(module, _, output): + if module not in cache or output.shape[1] > self.dims.n_text_ctx: + # save as-is, for the first token or cross attention + cache[module] = output + else: + cache[module] = torch.cat([cache[module], output], dim=1).detach() + return cache[module] + + def install_hooks(layer: nn.Module): + if isinstance(layer, MultiHeadAttention): + hooks.append(layer.key.register_forward_hook(save_to_cache)) + hooks.append(layer.value.register_forward_hook(save_to_cache)) + + self.decoder.apply(install_hooks) + return cache, hooks + + #detect_language = detect_language_function + transcribe = transcribe_function + decode = decode_function + +_MODELS = { + "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", + "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt", + "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt", + "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt", + "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt", + "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt", + "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt", + "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", + "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt", + "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", + "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", +} + +def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: + os.makedirs(root, exist_ok=True) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, os.path.basename(url)) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + with open(download_target, "rb") as f: + model_bytes = f.read() + if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: + return model_bytes if in_memory else download_target + else: + warnings.warn( + f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" + ) + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm( + total=int(source.info().get("Content-Length")), + ncols=80, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + model_bytes = open(download_target, "rb").read() + if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: + raise RuntimeError( + "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." + ) + + return model_bytes if in_memory else download_target + +def load_model( + name: str, + device: Optional[Union[str, torch.device]] = None, + download_root: str = None, + in_memory: bool = False, +) -> Whisper: + """ + Load a Whisper ASR model + + Parameters + ---------- + name : str + one of the official model names listed by `whisper.available_models()`, or + path to a model checkpoint containing the model dimensions and the model state_dict. + device : Union[str, torch.device] + the PyTorch device to put the model into + download_root: str + path to download the model files; by default, it uses "~/.cache/whisper" + in_memory: bool + whether to preload the model weights into host memory + + Returns + ------- + model : Whisper + The Whisper ASR model instance + """ + + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + if download_root is None: + default = os.path.join(os.path.expanduser("~"), ".cache") + download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper") + + if name in _MODELS: + checkpoint_file = _download(_MODELS[name], download_root, in_memory) + # alignment_heads = _ALIGNMENT_HEADS[name] + alignment_heads = None + elif os.path.isfile(name): + checkpoint_file = open(name, "rb").read() if in_memory else name + alignment_heads = None + else: + raise RuntimeError( + f"Model {name} not found; available models = {available_models()}" + ) + + with ( + io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") + ) as fp: + checkpoint = torch.load(fp, map_location=device) + del checkpoint_file + + dims = ModelDimensions(**checkpoint["dims"]) + model = Whisper(dims) + model.load_state_dict(checkpoint["model_state_dict"]) + + # if alignment_heads is not None: + # model.set_alignment_heads(alignment_heads) + + return model.to(device) + + diff --git a/egs/aishell/ASR/whisper/optim.py b/egs/aishell/ASR/whisper/optim.py new file mode 100644 index 000000000..abfb2092c --- /dev/null +++ b/egs/aishell/ASR/whisper/optim.py @@ -0,0 +1,1173 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import logging +import random +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import torch +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.optim import Optimizer + + +class BatchedOptimizer(Optimizer): + """ + This class adds to class Optimizer the capability to optimize parameters in batches: + it will stack the parameters and their grads for you so the optimizer can work + on tensors with an extra leading dimension. This is intended for speed with GPUs, + as it reduces the number of kernels launched in the optimizer. + + Args: + params: + """ + + def __init__(self, params, defaults): + super(BatchedOptimizer, self).__init__(params, defaults) + + @contextlib.contextmanager + def batched_params(self, param_group, group_params_names): + """ + This function returns (technically, yields) a list of + of tuples (p, state), where + p is a `fake` parameter that is stacked (over axis 0) from real parameters + that share the same shape, and its gradient is also stacked; + `state` is the state corresponding to this batch of parameters + (it will be physically located in the "state" for one of the real + parameters, the last one that has any particular shape and dtype). + + This function is decorated as a context manager so that it can + write parameters back to their "real" locations. + + The idea is, instead of doing: + + for p in group["params"]: + state = self.state[p] + ... + + you can do: + + with self.batched_params(group["params"]) as batches: + for p, state, p_names in batches: + ... + + + Args: + group: a parameter group, which is a list of parameters; should be + one of self.param_groups. + group_params_names: name for each parameter in group, + which is List[str]. + """ + batches = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches_names = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str + + assert len(param_group) == len(group_params_names) + for p, named_p in zip(param_group, group_params_names): + key = (str(p.dtype), *p.shape) + batches[key].append(p) + batches_names[key].append(named_p) + + batches_names_keys = list(batches_names.keys()) + sorted_idx = sorted( + range(len(batches_names)), key=lambda i: batches_names_keys[i] + ) + batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] + batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] + + stacked_params_dict = dict() + + # turn batches into a list, in deterministic order. + # tuples will contain tuples of (stacked_param, state, stacked_params_names), + # one for each batch in `batches`. + tuples = [] + + for batch, batch_names in zip(batches, batches_names): + p = batch[0] + # we arbitrarily store the state in the + # state corresponding to the 1st parameter in the + # group. class Optimizer will take care of saving/loading state. + state = self.state[p] + p_stacked = torch.stack(batch) + grad = torch.stack( + [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] + ) + p_stacked.grad = grad + stacked_params_dict[key] = p_stacked + tuples.append((p_stacked, state, batch_names)) + + yield tuples # <-- calling code will do the actual optimization here! + + for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for i, p in enumerate(batch): # batch is list of Parameter + p.copy_(stacked_params[i]) + + +class ScaledAdam(BatchedOptimizer): + """ + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) + + + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + Unlike common optimizers, which accept model.parameters() or groups of parameters(), + this optimizer could accept model.named_parameters() or groups of named_parameters(). + See comments of function _get_names_of_parameters for its 4 possible cases. + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. + clipping_update_period: if clipping_scale is specified, this is the period + """ + + def __init__( + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, + ): + + defaults = dict( + lr=lr, + clipping_scale=clipping_scale, + betas=betas, + scalar_lr_scale=scalar_lr_scale, + eps=eps, + param_min_rms=param_min_rms, + param_max_rms=param_max_rms, + scalar_max=scalar_max, + size_update_period=size_update_period, + clipping_update_period=clipping_update_period, + ) + + # If params only contains parameters or group of parameters, + # i.e when parameter names are not given, + # this flag will be set to False in funciton _get_names_of_parameters. + self.show_dominant_parameters = True + param_groups, parameters_names = self._get_names_of_parameters(params) + super(ScaledAdam, self).__init__(param_groups, defaults) + assert len(self.param_groups) == len(parameters_names) + self.parameters_names = parameters_names + + def _get_names_of_parameters( + self, params_or_named_params + ) -> Tuple[List[Dict], List[List[str]]]: + """ + Args: + params_or_named_params: according to the way ScaledAdam is initialized in train.py, + this argument could be one of following 4 cases, + case 1, a generator of parameter, e.g.: + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0) + + case 2, a list of parameter groups with different config, e.g.: + model_param_groups = [ + {'params': model.encoder.parameters(), 'lr': 0.05}, + {'params': model.decoder.parameters(), 'lr': 0.01}, + {'params': model.joiner.parameters(), 'lr': 0.03}, + ] + optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0) + + case 3, a generator of named_parameter, e.g.: + optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0) + + case 4, a list of named_parameter groups with different config, e.g.: + model_named_param_groups = [ + {'named_params': model.encoder.named_parameters(), 'lr': 0.05}, + {'named_params': model.decoder.named_parameters(), 'lr': 0.01}, + {'named_params': model.joiner.named_parameters(), 'lr': 0.03}, + ] + optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0) + + For case 1 and case 2, input params is used to initialize the underlying torch.optimizer. + For case 3 and case 4, firstly, names and params are extracted from input named_params, + then, these extracted params are used to initialize the underlying torch.optimizer, + and these extracted names are mainly used by function + `_show_gradient_dominating_parameter` + + Returns: + Returns a tuple containing 2 elements: + - `param_groups` with type List[Dict], each Dict element is a parameter group. + An example of `param_groups` could be: + [ + {'params': `one iterable of Parameter`, 'lr': 0.05}, + {'params': `another iterable of Parameter`, 'lr': 0.08}, + {'params': `a third iterable of Parameter`, 'lr': 0.1}, + ] + - `param_gruops_names` with type List[List[str]], + each `List[str]` is for a group['params'] in param_groups, + and each `str` is the name of a parameter. + A dummy name "foo" is related to each parameter, + if input are params without names, i.e. case 1 or case 2. + """ + # variable naming convention in this function: + # p is short for param. + # np is short for named_param. + # p_or_np is short for param_or_named_param. + # cur is short for current. + # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}. + # groups is a List[group] + + iterable_or_groups = list(params_or_named_params) + if len(iterable_or_groups) == 0: + raise ValueError("optimizer got an empty parameter list") + + # The first value of returned tuple. A list of dicts containing at + # least 'params' as a key. + param_groups = [] + + # The second value of returned tuple, + # a List[List[str]], each sub-List is for a group. + param_groups_names = [] + + if not isinstance(iterable_or_groups[0], dict): + # case 1 or case 3, + # the input is an iterable of parameter or named parameter. + param_iterable_cur_group = [] + param_names_cur_group = [] + for p_or_np in iterable_or_groups: + if isinstance(p_or_np, tuple): + # case 3 + name, param = p_or_np + else: + # case 1 + assert isinstance(p_or_np, torch.Tensor) + param = p_or_np + # Assign a dummy name as a placeholder + name = "foo" + self.show_dominant_parameters = False + param_iterable_cur_group.append(param) + param_names_cur_group.append(name) + param_groups.append({"params": param_iterable_cur_group}) + param_groups_names.append(param_names_cur_group) + else: + # case 2 or case 4 + # the input is groups of parameter or named parameter. + for cur_group in iterable_or_groups: + assert "named_params" in cur_group + name_list = [ x[0] for x in cur_group["named_params"] ] + p_list = [ x[1] for x in cur_group["named_params"] ] + del cur_group["named_params"] + cur_group["params"] = p_list + param_groups.append(cur_group) + param_groups_names.append(name_list) + + return param_groups, param_groups_names + + def __setstate__(self, state): + super(ScaledAdam, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + batch = True + + for group, group_params_names in zip(self.param_groups, self.parameters_names): + + with self.batched_params(group["params"], group_params_names) as batches: + + # batches is list of pairs (stacked_param, state). stacked_param is like + # a regular parameter, and will have a .grad, but the 1st dim corresponds to + # a stacking dim, it is not a real dim. + + if ( + len(batches[0][1]) == 0 + ): # if len(first state) == 0: not yet initialized + clipping_scale = 1 + else: + clipping_scale = self._get_clipping_scale(group, batches) + + for p, state, _ in batches: + # Perform optimization step. + # grad is not going to be None, we handled that when creating the batches. + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + # State initialization + if len(state) == 0: + self._init_state(group, p, state) + + self._step_one_batch(group, p, state, clipping_scale) + + return loss + + def _init_state(self, group: dict, p: Tensor, state: dict): + """ + Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p + is actually the batch dimension, corresponding to batched-together + parameters of a given shape. + + + Args: + group: Dict to look up configuration values. + p: The parameter that we are initializing the state for + state: Dict from string to whatever state we are initializing + """ + size_update_period = group["size_update_period"] + + state["step"] = 0 + + kwargs = {"device": p.device, "dtype": p.dtype} + + # 'delta' implements conventional momentum. There are + # several different kinds of update going on, so rather than + # compute "exp_avg" like in Adam, we store and decay a + # parameter-change "delta", which combines all forms of + # update. this is equivalent to how it's done in Adam, + # except for the first few steps. + state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) + + batch_size = p.shape[0] + numel = p.numel() // batch_size + + if numel > 1: + # "param_rms" just periodically records the scalar root-mean-square value of + # the parameter tensor. + # it has a shape like (batch_size, 1, 1, 1, 1) + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + state["param_rms"] = param_rms + + state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) + state["scale_grads"] = torch.zeros( + size_update_period, *param_rms.shape, **kwargs + ) + + # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) + + def _get_clipping_scale( + self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] + ) -> float: + """ + Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients + by this amount before applying the rest of the update. + + Args: + group: the parameter group, an item in self.param_groups + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + """ + assert len(tuples) >= 1 + clipping_scale = group["clipping_scale"] + (first_p, first_state, _) = tuples[0] + step = first_state["step"] + if clipping_scale is None or step == 0: + # no clipping. return early on step == 0 because the other + # parameters' state won't have been initialized yet. + return 1.0 + clipping_update_period = group["clipping_update_period"] + + tot_sumsq = torch.tensor(0.0, device=first_p.device) + for (p, state, param_names) in tuples: + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + if p.numel() == p.shape[0]: # a batch of scalars + tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] + else: + tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() + + tot_norm = tot_sumsq.sqrt() + if "model_norms" not in first_state: + first_state["model_norms"] = torch.zeros( + clipping_update_period, device=p.device + ) + first_state["model_norms"][step % clipping_update_period] = tot_norm + + if step % clipping_update_period == 0: + # Print some stats. + # We don't reach here if step == 0 because we would have returned + # above. + sorted_norms = first_state["model_norms"].sort()[0].to("cpu") + quartiles = [] + for n in range(0, 5): + index = min( + clipping_update_period - 1, (clipping_update_period // 4) * n + ) + quartiles.append(sorted_norms[index].item()) + + median = quartiles[2] + threshold = clipping_scale * median + first_state["model_norm_threshold"] = threshold + percent_clipped = ( + first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state + else 0.0 + ) + first_state["num_clipped"] = 0 + quartiles = " ".join(["%.3e" % x for x in quartiles]) + logging.info( + f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" + ) + + if step < clipping_update_period: + return 1.0 # We have not yet estimated a norm to clip to. + else: + try: + model_norm_threshold = first_state["model_norm_threshold"] + except KeyError: + logging.info( + "Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?" + ) + return 1.0 + ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) + if ans < 1.0: + first_state["num_clipped"] += 1 + if ans < 0.1: + logging.warn( + f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" + ) + if self.show_dominant_parameters: + assert p.shape[0] == len(param_names) + self._show_gradient_dominating_parameter(tuples, tot_sumsq) + return ans + + def _show_gradient_dominating_parameter( + self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor + ): + """ + Show information of parameter which dominates tot_sumsq. + + Args: + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + tot_sumsq: sumsq of all parameters. Though it's could be calculated + from tuples, we still pass it to save some time. + """ + all_sumsq_orig = {} + for (p, state, batch_param_names) in tuples: + # p is a stacked batch parameters. + batch_grad = p.grad + if p.numel() == p.shape[0]: # a batch of scalars + batch_sumsq_orig = batch_grad**2 + # Dummy values used by following `zip` statement. + batch_rms_orig = torch.ones(p.shape[0]) + else: + batch_rms_orig = state["param_rms"] + batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum( + dim=list(range(1, batch_grad.ndim)) + ) + + for name, sumsq_orig, rms, grad in zip( + batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad + ): + + proportion_orig = sumsq_orig / tot_sumsq + all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) + + assert torch.isclose( + sum([value[0] for value in all_sumsq_orig.values()]).cpu(), + torch.tensor(1.0), + ) + sorted_by_proportion = { + k: v + for k, v in sorted( + all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True + ) + } + dominant_param_name = next(iter(sorted_by_proportion)) + ( + dominant_proportion, + dominant_sumsq, + dominant_rms, + dominant_grad, + ) = sorted_by_proportion[dominant_param_name] + logging.info( + f"Parameter dominating tot_sumsq {dominant_param_name}" + f" with proportion {dominant_proportion:.2f}," + f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" + f"={dominant_sumsq:.3e}," + f" grad_sumsq={(dominant_grad**2).sum():.3e}," + f" orig_rms_sq={(dominant_rms**2).item():.3e}" + ) + + def _step_one_batch( + self, group: dict, p: Tensor, state: dict, clipping_scale: float + ): + """ + Do the step for one parameter, which is actually going to be a batch of + `real` parameters, with dim 0 as the batch dim. + Args: + group: dict to look up configuration values + p: parameter to update (actually multiple parameters stacked together + as a batch) + state: state-dict for p, to look up the optimizer state + """ + lr = group["lr"] + size_update_period = group["size_update_period"] + beta1 = group["betas"][0] + + grad = p.grad + if clipping_scale != 1.0: + grad = grad * clipping_scale + step = state["step"] + delta = state["delta"] + + delta.mul_(beta1) + batch_size = p.shape[0] + numel = p.numel() // batch_size + if numel > 1: + # Update the size/scale of p, and set param_rms + scale_grads = state["scale_grads"] + scale_grads[step % size_update_period] = (p * grad).sum( + dim=list(range(1, p.ndim)), keepdim=True + ) + if step % size_update_period == size_update_period - 1: + param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) + param_rms.copy_( + (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + ) + if step > 0: + # self._size_update() learns the overall scale on the + # parameter, by shrinking or expanding it. + self._size_update(group, scale_grads, p, state) + + if numel == 1: + # For parameters with 1 element we just use regular Adam. + # Updates delta. + self._step_scalar(group, p, state) + else: + self._step(group, p, state) + + state["step"] = step + 1 + + def _size_update( + self, group: dict, scale_grads: Tensor, p: Tensor, state: dict + ) -> None: + """ + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. + + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. + p: The parameter to update + state: The state-dict of p + """ + + param_rms = state["param_rms"] + beta1, beta2 = group["betas"] + size_lr = group["lr"] * group["scalar_lr_scale"] + param_min_rms = group["param_min_rms"] + param_max_rms = group["param_max_rms"] + eps = group["eps"] + step = state["step"] + batch_size = p.shape[0] + + size_update_period = scale_grads.shape[0] + # correct beta2 for the size update period: we will have + # faster decay at this level. + beta2_corr = beta2**size_update_period + + scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) + scale_exp_avg_sq.mul_(beta2_corr).add_( + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) + + # The 1st time we reach here is when size_step == 1. + size_step = (step + 1) // size_update_period + bias_correction2 = 1 - beta2_corr**size_step + # we don't bother with bias_correction1; this will help prevent divergence + # at the start of training. + + denom = scale_exp_avg_sq.sqrt() + eps + + scale_step = ( + -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + ) + + is_too_small = param_rms < param_min_rms + + # when the param gets too small, just don't shrink it any further. + scale_step.masked_fill_(is_too_small, 0.0) + + # and ensure the parameter rms after update never exceeds param_max_rms. + # We have to look at the trained model for parameters at or around the + # param_max_rms, because sometimes they can indicate a problem with the + # topology or settings. + scale_step = torch.minimum(scale_step, + (param_max_rms - param_rms) / param_rms) + + delta = state["delta"] + # the factor of (1-beta1) relates to momentum. + delta.add_(p * scale_step, alpha=(1 - beta1)) + + def _step(self, group: dict, p: Tensor, state: dict): + """ + This function does the core update of self.step(), in the case where the members of + the batch have more than 1 element. + + Args: + group: A dict which will be used to look up configuration values + p: The parameter to be updated + grad: The grad of p + state: The state-dict corresponding to parameter p + + This function modifies p. + """ + grad = p.grad + lr = group["lr"] + beta1, beta2 = group["betas"] + eps = group["eps"] + param_min_rms = group["param_min_rms"] + step = state["step"] + + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) + + this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) + bias_correction2 = 1 - beta2 ** (this_step + 1) + if bias_correction2 < 0.99: + # note: not in-place. + exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) + + denom = exp_avg_sq.sqrt() + denom += eps + grad = grad / denom + + alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) + + delta = state["delta"] + delta.add_(grad * alpha) + p.add_(delta) + + def _step_scalar(self, group: dict, p: Tensor, state: dict): + """ + A simplified form of the core update for scalar tensors, where we cannot get a good + estimate of the parameter rms. + """ + beta1, beta2 = group["betas"] + scalar_max = group["scalar_max"] + eps = group["eps"] + lr = group["lr"] * group["scalar_lr_scale"] + grad = p.grad + + exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # bias_correction2 is like in Adam. Don't bother with bias_correction1; + # slower update at the start will help stability anyway. + bias_correction2 = 1 - beta2 ** (state["step"] + 1) + denom = (exp_avg_sq / bias_correction2).sqrt() + eps + + delta = state["delta"] + delta.add_(grad / denom, alpha=-lr * (1 - beta1)) + p.clamp_(min=-scalar_max, max=scalar_max) + p.add_(delta) + + +class LRScheduler(object): + """ + Base-class for learning rate schedulers where the learning-rate depends on both the + batch and the epoch. + """ + + def __init__(self, optimizer: Optimizer, verbose: bool = False): + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + self.optimizer = optimizer + self.verbose = verbose + + for group in optimizer.param_groups: + group.setdefault("base_lr", group["lr"]) + + self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] + + self.epoch = 0 + self.batch = 0 + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return { + "base_lrs": self.base_lrs, + "epoch": self.epoch, + "batch": self.batch, + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self) -> List[float]: + """Return last computed learning rate by current scheduler. Will be a list of float.""" + return self._last_lr + + def get_lr(self): + # Compute list of learning rates from self.epoch and self.batch and + # self.base_lrs; this must be overloaded by the user. + # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] + raise NotImplementedError + + def step_batch(self, batch: Optional[int] = None) -> None: + # Step the batch index, or just set it. If `batch` is specified, it + # must be the batch index from the start of training, i.e. summed over + # all epochs. + # You can call this in any order; if you don't provide 'batch', it should + # of course be called once per batch. + if batch is not None: + self.batch = batch + else: + self.batch = self.batch + 1 + self._set_lrs() + + def step_epoch(self, epoch: Optional[int] = None): + # Step the epoch index, or just set it. If you provide the 'epoch' arg, + # you should call this at the start of the epoch; if you don't provide the 'epoch' + # arg, you should call it at the end of the epoch. + if epoch is not None: + self.epoch = epoch + else: + self.epoch = self.epoch + 1 + self._set_lrs() + + def _set_lrs(self): + values = self.get_lr() + assert len(values) == len(self.optimizer.param_groups) + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group["lr"] = lr + self.print_lr(self.verbose, i, lr) + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + def print_lr(self, is_verbose, group, lr): + """Display the current learning rate.""" + if is_verbose: + logging.info( + f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" + f" of group {group} to {lr:.4e}." + ) + + +class Eden(LRScheduler): + """ + Eden scheduler. + The basic formula (before warmup) is: + lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * + (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup + where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches + and then stays constant at 1. + + + E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam + + Args: + optimizer: the optimizer to change the learning rates on + lr_batches: the number of batches after which we start significantly + decreasing the learning rate, suggest 5000. + lr_epochs: the number of epochs after which we start significantly + decreasing the learning rate, suggest 6 if you plan to do e.g. + 20 to 40 epochs, but may need smaller number if dataset is huge + and you will do few epochs. + """ + + def __init__( + self, + optimizer: Optimizer, + lr_batches: Union[int, float], + lr_epochs: Union[int, float], + warmup_batches: Union[int, float] = 500.0, + warmup_start: float = 0.5, + verbose: bool = False, + ): + super(Eden, self).__init__(optimizer, verbose) + self.lr_batches = lr_batches + self.lr_epochs = lr_epochs + self.warmup_batches = warmup_batches + + assert 0.0 <= warmup_start <= 1.0, warmup_start + self.warmup_start = warmup_start + + def get_lr(self): + factor = ( + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + ) ** -0.25 * ( + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ) + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) + # else 0.5 + 0.5 * (self.batch / self.warmup_batches) + ) + + return [x * factor * warmup_factor for x in self.base_lrs] + + +def _test_eden(): + m = torch.nn.Linear(100, 100) + optim = ScaledAdam(m.parameters(), lr=0.03) + + scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) + + for epoch in range(10): + scheduler.step_epoch(epoch) # sets epoch to `epoch` + + for step in range(20): + x = torch.randn(200, 100).detach() + x.requires_grad = True + y = m(x) + dy = torch.randn(200, 100).detach() + f = (y * dy).sum() + f.backward() + + optim.step() + scheduler.step_batch() + optim.zero_grad() + + logging.info(f"last lr = {scheduler.get_last_lr()}") + logging.info(f"state dict = {scheduler.state_dict()}") + + +# This is included mostly as a baseline for ScaledAdam. +class Eve(Optimizer): + """ + Implements Eve algorithm. This is a modified version of AdamW with a special + way of setting the weight-decay / shrinkage-factor, which is designed to make the + rms of the parameters approach a particular target_rms (default: 0.1). This is + for use with networks with 'scaled' versions of modules (see scaling.py), which + will be close to invariant to the absolute scale on the parameter matrix. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + Eve is unpublished so far. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 3e-4; + this value means that the weight would decay significantly after + about 3k minibatches. Is not multiplied by learning rate, but + is conditional on RMS-value of parameter being > target_rms. + target_rms (float, optional): target root-mean-square value of + parameters, if they fall below this we will stop applying weight decay. + + + .. _Adam: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.98), + eps=1e-8, + weight_decay=1e-3, + target_rms=0.1, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0 <= weight_decay <= 0.1: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0 < target_rms <= 10.0: + raise ValueError("Invalid target_rms value: {}".format(target_rms)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + target_rms=target_rms, + ) + super(Eve, self).__init__(params, defaults) + + def __setstate__(self, state): + super(Eve, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError("AdamW does not support sparse gradients") + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + beta1, beta2 = group["betas"] + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( + group["eps"] + ) + + step_size = group["lr"] / bias_correction1 + target_rms = group["target_rms"] + weight_decay = group["weight_decay"] + + if p.numel() > 1: + # avoid applying this weight-decay on "scaling factors" + # (which are scalar). + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + p.mul_(1 - (weight_decay * is_above_target_rms)) + + p.addcdiv_(exp_avg, denom, value=-step_size) + + if random.random() < 0.0005: + step = (exp_avg / denom) * step_size + logging.info( + f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" + ) + + return loss + + +def _test_scaled_adam(hidden_dim: int): + import timeit + + from scaling import ScaledLinear + + E = 100 + B = 4 + T = 2 + logging.info("in test_eve_cain") + # device = torch.device('cuda') + device = torch.device("cpu") + dtype = torch.float32 + + fix_random_seed(42) + # these input_magnitudes and output_magnitudes are to test that + # Abel is working as we expect and is able to adjust scales of + # different dims differently. + input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + + for iter in [1, 0]: + fix_random_seed(42) + Linear = torch.nn.Linear if iter == 0 else ScaledLinear + + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) + + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] + + if iter == 0: + optim = Eve(m.parameters(), lr=0.003) + elif iter == 1: + optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) + + start = timeit.default_timer() + avg_loss = 0.0 + for epoch in range(180): + scheduler.step_epoch() + # if epoch == 100 and iter in [2,3]: + # optim.reset_speedup() # check it doesn't crash. + + # if epoch == 130: + # opts = diagnostics.TensorDiagnosticOptions( + # 2 ** 22 + # ) # allow 4 megabytes per sub-module + # diagnostic = diagnostics.attach_diagnostics(m, opts) + + for n, (x, y) in enumerate(train_pairs): + y_out = m(x) + loss = ((y_out - y) ** 2).mean() * 100.0 + if epoch == 0 and n == 0: + avg_loss = loss.item() + else: + avg_loss = 0.98 * avg_loss + 0.02 * loss.item() + if n == 0 and epoch % 5 == 0: + # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + lr = scheduler.get_last_lr()[0] + logging.info( + f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" + ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + loss.log().backward() + optim.step() + optim.zero_grad() + scheduler.step_batch() + + # diagnostic.print_diagnostics() + + stop = timeit.default_timer() + logging.info(f"Iter={iter}, Time taken: {stop - start}") + + logging.info(f"last lr = {scheduler.get_last_lr()}") + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) + logging.info(f"input_magnitudes = {input_magnitudes}") + logging.info(f"output_magnitudes = {output_magnitudes}") + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + logging.getLogger().setLevel(logging.INFO) + import subprocess + + s = subprocess.check_output( + "git status -uno .; git log -1; git diff HEAD .", shell=True + ) + logging.info(s) + import sys + + if len(sys.argv) > 1: + hidden_dim = int(sys.argv[1]) + else: + hidden_dim = 200 + + _test_scaled_adam(hidden_dim) + _test_eden() diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py index b398ddb81..95af2f056 100644 --- a/egs/aishell/ASR/whisper/train.py +++ b/egs/aishell/ASR/whisper/train.py @@ -88,7 +88,7 @@ from icefall.utils import ( ) import whisper - +from model import load_model from label_smoothing import LabelSmoothingLoss LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] @@ -631,7 +631,7 @@ def compute_loss( feature = feature.to(device) feature = feature.transpose(1, 2) # (N, C, T) # pad feature from B,80,T to B,80,3000 - feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1])) + # feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1])) supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) @@ -923,7 +923,8 @@ def run(rank, world_size, args): logging.info("About to create model") - model = whisper.load_model("medium") + #model = whisper.load_model("medium") + model = load_model("medium") del model.alignment_heads params.tokenizer = whisper.tokenizer.get_tokenizer( model.is_multilingual, language="zh", task="transcribe"