diff --git a/egs/wenetspeech/ASR/whisper/ds_config_zero1.json b/egs/wenetspeech/ASR/whisper/ds_config_zero1.json
deleted file mode 100644
index bf8cc0452..000000000
--- a/egs/wenetspeech/ASR/whisper/ds_config_zero1.json
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "fp16": {
- "enabled": true,
- "loss_scale": 0,
- "loss_scale_window": 100,
- "initial_scale_power": 16,
- "hysteresis": 2,
- "min_loss_scale": 0.01
- },
- "zero_optimization": {
- "stage": 1,
- "allgather_partitions": true,
- "allgather_bucket_size": 2e8,
- "overlap_comm": true,
- "reduce_scatter": true,
- "reduce_bucket_size": 2e8,
- "contiguous_gradients": true
- },
- "optimizer": {
- "type": "Adam",
- "params": {
- "lr": 1e-5
- }
- },
- "scheduler": {
- "type": "WarmupLR",
- "params": {
- "warmup_min_lr": 0,
- "warmup_max_lr": 1e-5,
- "warmup_num_steps": 100
- }
- },
- "gradient_accumulation_steps": 1,
- "gradient_clipping": 5,
- "steps_per_print": 50,
- "train_micro_batch_size_per_gpu": 1,
- "wall_clock_breakdown": false
-}
diff --git a/egs/wenetspeech/ASR/whisper/ds_config_zero1.json b/egs/wenetspeech/ASR/whisper/ds_config_zero1.json
new file mode 120000
index 000000000..af7162d6c
--- /dev/null
+++ b/egs/wenetspeech/ASR/whisper/ds_config_zero1.json
@@ -0,0 +1 @@
+../../../aishell/ASR/whisper/ds_config_zero1.json
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/whisper/label_smoothing.py b/egs/wenetspeech/ASR/whisper/label_smoothing.py
deleted file mode 100644
index 52d2eda3b..000000000
--- a/egs/wenetspeech/ASR/whisper/label_smoothing.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-
-
-class LabelSmoothingLoss(torch.nn.Module):
- """
- Implement the LabelSmoothingLoss proposed in the following paper
- https://arxiv.org/pdf/1512.00567.pdf
- (Rethinking the Inception Architecture for Computer Vision)
-
- """
-
- def __init__(
- self,
- ignore_index: int = -1,
- label_smoothing: float = 0.1,
- reduction: str = "sum",
- ) -> None:
- """
- Args:
- ignore_index:
- ignored class id
- label_smoothing:
- smoothing rate (0.0 means the conventional cross entropy loss)
- reduction:
- It has the same meaning as the reduction in
- `torch.nn.CrossEntropyLoss`. It can be one of the following three
- values: (1) "none": No reduction will be applied. (2) "mean": the
- mean of the output is taken. (3) "sum": the output will be summed.
- """
- super().__init__()
- assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}"
- assert reduction in ("none", "sum", "mean"), reduction
- self.ignore_index = ignore_index
- self.label_smoothing = label_smoothing
- self.reduction = reduction
-
- def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
- """
- Compute loss between x and target.
-
- Args:
- x:
- prediction of dimension
- (batch_size, input_length, number_of_classes).
- target:
- target masked with self.ignore_index of
- dimension (batch_size, input_length).
-
- Returns:
- A scalar tensor containing the loss without normalization.
- """
- assert x.ndim == 3
- assert target.ndim == 2
- assert x.shape[:2] == target.shape
- num_classes = x.size(-1)
- x = x.reshape(-1, num_classes)
- # Now x is of shape (N*T, C)
-
- # We don't want to change target in-place below,
- # so we make a copy of it here
- target = target.clone().reshape(-1)
-
- ignored = target == self.ignore_index
-
- # See https://github.com/k2-fsa/icefall/issues/240
- # and https://github.com/k2-fsa/icefall/issues/297
- # for why we don't use target[ignored] = 0 here
- target = torch.where(ignored, torch.zeros_like(target), target)
-
- true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x)
-
- true_dist = (
- true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes
- )
-
- # Set the value of ignored indexes to 0
- #
- # See https://github.com/k2-fsa/icefall/issues/240
- # and https://github.com/k2-fsa/icefall/issues/297
- # for why we don't use true_dist[ignored] = 0 here
- true_dist = torch.where(
- ignored.unsqueeze(1).repeat(1, true_dist.shape[1]),
- torch.zeros_like(true_dist),
- true_dist,
- )
-
- loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
- if self.reduction == "sum":
- return loss.sum()
- elif self.reduction == "mean":
- return loss.sum() / (~ignored).sum()
- else:
- return loss.sum(dim=-1)
diff --git a/egs/wenetspeech/ASR/whisper/label_smoothing.py b/egs/wenetspeech/ASR/whisper/label_smoothing.py
new file mode 120000
index 000000000..e9d239fff
--- /dev/null
+++ b/egs/wenetspeech/ASR/whisper/label_smoothing.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/conformer_ctc/label_smoothing.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/whisper/model.py b/egs/wenetspeech/ASR/whisper/model.py
deleted file mode 100755
index 9ec412513..000000000
--- a/egs/wenetspeech/ASR/whisper/model.py
+++ /dev/null
@@ -1,431 +0,0 @@
-import torch
-import torch.nn as nn
-import base64
-import gzip
-import warnings
-from tqdm import tqdm
-from dataclasses import dataclass
-from typing import Dict, Iterable, Optional, Union
-import os
-import urllib
-import hashlib
-import numpy as np
-
-import torch.nn.functional as F
-from torch import Tensor
-
-from whisper.decoding import decode as decode_function
-from whisper.transcribe import transcribe as transcribe_function
-
-
-@dataclass
-class ModelDimensions:
- n_mels: int
- n_audio_ctx: int
- n_audio_state: int
- n_audio_head: int
- n_audio_layer: int
- n_vocab: int
- n_text_ctx: int
- n_text_state: int
- n_text_head: int
- n_text_layer: int
-
-
-class LayerNorm(nn.LayerNorm):
- def forward(self, x: Tensor) -> Tensor:
- return super().forward(x.float()).type(x.dtype)
-
-
-class Linear(nn.Linear):
- def forward(self, x: Tensor) -> Tensor:
- return F.linear(
- x,
- self.weight.to(x.dtype),
- None if self.bias is None else self.bias.to(x.dtype),
- )
-
-
-class Conv1d(nn.Conv1d):
- def _conv_forward(
- self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
- ) -> Tensor:
- return super()._conv_forward(
- x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
- )
-
-
-def sinusoids(length, channels, max_timescale=10000):
- """Returns sinusoids for positional embedding"""
- assert channels % 2 == 0
- log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
- inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
- scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
- return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
-
-
-class MultiHeadAttention(nn.Module):
- def __init__(self, n_state: int, n_head: int):
- super().__init__()
- self.n_head = n_head
- self.query = Linear(n_state, n_state)
- self.key = Linear(n_state, n_state, bias=False)
- self.value = Linear(n_state, n_state)
- self.out = Linear(n_state, n_state)
-
- def forward(
- self,
- x: Tensor,
- xa: Optional[Tensor] = None,
- mask: Optional[Tensor] = None,
- kv_cache: Optional[dict] = None,
- ):
- q = self.query(x)
-
- if kv_cache is None or xa is None or self.key not in kv_cache:
- # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
- # otherwise, perform key/value projections for self- or cross-attention as usual.
- k = self.key(x if xa is None else xa)
- v = self.value(x if xa is None else xa)
- else:
- # for cross-attention, calculate keys and values once and reuse in subsequent calls.
- k = kv_cache[self.key]
- v = kv_cache[self.value]
-
- wv, qk = self.qkv_attention(q, k, v, mask)
- return self.out(wv), qk
-
- def qkv_attention(
- self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
- ):
- n_batch, n_ctx, n_state = q.shape
- scale = (n_state // self.n_head) ** -0.25
- q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
- k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
- v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
-
- qk = q @ k
- if mask is not None:
- qk = qk + mask[:n_ctx, :n_ctx]
- qk = qk.float()
-
- w = F.softmax(qk, dim=-1).to(q.dtype)
- return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
-
-
-class ResidualAttentionBlock(nn.Module):
- def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
- super().__init__()
-
- self.attn = MultiHeadAttention(n_state, n_head)
- self.attn_ln = LayerNorm(n_state)
-
- self.cross_attn = (
- MultiHeadAttention(n_state, n_head) if cross_attention else None
- )
- self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
-
- n_mlp = n_state * 4
- self.mlp = nn.Sequential(
- Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
- )
- self.mlp_ln = LayerNorm(n_state)
-
- def forward(
- self,
- x: Tensor,
- xa: Optional[Tensor] = None,
- mask: Optional[Tensor] = None,
- kv_cache: Optional[dict] = None,
- ):
- x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
- if self.cross_attn:
- x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
- x = x + self.mlp(self.mlp_ln(x))
- return x
-
-
-class AudioEncoder(nn.Module):
- def __init__(
- self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
- ):
- super().__init__()
- self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
- self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
- self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
-
- self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
- [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
- )
- self.ln_post = LayerNorm(n_state)
-
- def forward(self, x: Tensor):
- """
- x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
- the mel spectrogram of the audio
- """
- x = F.gelu(self.conv1(x))
- x = F.gelu(self.conv2(x))
- x = x.permute(0, 2, 1)
-
- # change whisper to process audio with any length
- x = (x + self.positional_embedding[:x.shape[1],:]).to(x.dtype)
-
- for block in self.blocks:
- x = block(x)
-
- x = self.ln_post(x)
- return x
-
-
-class TextDecoder(nn.Module):
- def __init__(
- self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
- ):
- super().__init__()
-
- self.token_embedding = nn.Embedding(n_vocab, n_state)
- self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
-
- self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
- [
- ResidualAttentionBlock(n_state, n_head, cross_attention=True)
- for _ in range(n_layer)
- ]
- )
- self.ln = LayerNorm(n_state)
-
- mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
- self.register_buffer("mask", mask, persistent=False)
-
- def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
- """
- x : torch.LongTensor, shape = (batch_size, <= n_ctx)
- the text tokens
- xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
- the encoded audio features to be attended on
- """
- offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
- x = (
- self.token_embedding(x)
- + self.positional_embedding[offset : offset + x.shape[-1]]
- )
- x = x.to(xa.dtype)
-
- for block in self.blocks:
- x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
-
- x = self.ln(x)
- logits = (
- x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
- ).float()
-
- return logits
-
-class Whisper(nn.Module):
- def __init__(self, dims: ModelDimensions):
- super().__init__()
- self.dims = dims
- self.encoder = AudioEncoder(
- self.dims.n_mels,
- self.dims.n_audio_ctx,
- self.dims.n_audio_state,
- self.dims.n_audio_head,
- self.dims.n_audio_layer,
- )
- self.decoder = TextDecoder(
- self.dims.n_vocab,
- self.dims.n_text_ctx,
- self.dims.n_text_state,
- self.dims.n_text_head,
- self.dims.n_text_layer,
- )
- # use the last half layers for alignment by default; see `set_alignment_heads()` below
- all_heads = torch.zeros(
- self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
- )
- all_heads[self.dims.n_text_layer // 2 :] = True
- self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
-
- def set_alignment_heads(self, dump: bytes):
- array = np.frombuffer(
- gzip.decompress(base64.b85decode(dump)), dtype=bool
- ).copy()
- mask = torch.from_numpy(array).reshape(
- self.dims.n_text_layer, self.dims.n_text_head
- )
- self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
-
- def embed_audio(self, mel: torch.Tensor):
- return self.encoder(mel)
-
- def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
- return self.decoder(tokens, audio_features)
-
- def forward(
- self, mel: torch.Tensor, tokens: torch.Tensor
- ) -> Dict[str, torch.Tensor]:
- return self.decoder(tokens, self.encoder(mel))
-
- @property
- def device(self):
- return next(self.parameters()).device
-
- @property
- def is_multilingual(self):
- return self.dims.n_vocab >= 51865
-
- @property
- def num_languages(self):
- return self.dims.n_vocab - 51765 - int(self.is_multilingual)
-
- def install_kv_cache_hooks(self, cache: Optional[dict] = None):
- """
- The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
- tensors calculated for the previous positions. This method returns a dictionary that stores
- all caches, and the necessary hooks for the key and value projection modules that save the
- intermediate tensors to be reused during later calculations.
-
- Returns
- -------
- cache : Dict[nn.Module, torch.Tensor]
- A dictionary object mapping the key/value projection modules to its cache
- hooks : List[RemovableHandle]
- List of PyTorch RemovableHandle objects to stop the hooks to be called
- """
- cache = {**cache} if cache is not None else {}
- hooks = []
-
- def save_to_cache(module, _, output):
- if module not in cache or output.shape[1] > self.dims.n_text_ctx:
- # save as-is, for the first token or cross attention
- cache[module] = output
- else:
- cache[module] = torch.cat([cache[module], output], dim=1).detach()
- return cache[module]
-
- def install_hooks(layer: nn.Module):
- if isinstance(layer, MultiHeadAttention):
- hooks.append(layer.key.register_forward_hook(save_to_cache))
- hooks.append(layer.value.register_forward_hook(save_to_cache))
-
- self.decoder.apply(install_hooks)
- return cache, hooks
-
- transcribe = transcribe_function
- decode = decode_function
-
-_MODELS = {
- "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
- "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
- "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
- "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
- "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
- "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
- "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
- "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
- "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
- "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
- "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
- "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
-}
-
-def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
- os.makedirs(root, exist_ok=True)
-
- expected_sha256 = url.split("/")[-2]
- download_target = os.path.join(root, os.path.basename(url))
-
- if os.path.exists(download_target) and not os.path.isfile(download_target):
- raise RuntimeError(f"{download_target} exists and is not a regular file")
-
- if os.path.isfile(download_target):
- with open(download_target, "rb") as f:
- model_bytes = f.read()
- if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
- return model_bytes if in_memory else download_target
- else:
- warnings.warn(
- f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
- )
-
- with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
- with tqdm(
- total=int(source.info().get("Content-Length")),
- ncols=80,
- unit="iB",
- unit_scale=True,
- unit_divisor=1024,
- ) as loop:
- while True:
- buffer = source.read(8192)
- if not buffer:
- break
-
- output.write(buffer)
- loop.update(len(buffer))
-
- model_bytes = open(download_target, "rb").read()
- if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
- raise RuntimeError(
- "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
- )
-
- return model_bytes if in_memory else download_target
-
-def load_model(
- name: str,
- device: Optional[Union[str, torch.device]] = 'cpu',
- download_root: str = None,
- in_memory: bool = False,
-) -> Whisper:
- """
- Load a Whisper ASR model
-
- Parameters
- ----------
- name : str
- one of the official model names listed by `whisper.available_models()`, or
- path to a model checkpoint containing the model dimensions and the model state_dict.
- device : Union[str, torch.device]
- the PyTorch device to put the model into
- download_root: str
- path to download the model files; by default, it uses "~/.cache/whisper"
- in_memory: bool
- whether to preload the model weights into host memory
-
- Returns
- -------
- model : Whisper
- The Whisper ASR model instance
- """
-
- # if device is None:
- # device = "cuda" if torch.cuda.is_available() else "cpu"
- if download_root is None:
- default = os.path.join(os.path.expanduser("~"), ".cache")
- download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
-
- if name in _MODELS:
- checkpoint_file = _download(_MODELS[name], download_root, in_memory)
- # alignment_heads = _ALIGNMENT_HEADS[name]
- alignment_heads = None
- elif os.path.isfile(name):
- checkpoint_file = open(name, "rb").read() if in_memory else name
- alignment_heads = None
- else:
- raise RuntimeError(
- f"Model {name} not found; available models = {available_models()}"
- )
-
- with (
- io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
- ) as fp:
- checkpoint = torch.load(fp, map_location=device)
- del checkpoint_file
-
- dims = ModelDimensions(**checkpoint["dims"])
- model = Whisper(dims)
- model.load_state_dict(checkpoint["model_state_dict"])
-
- return model.to(device)
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/whisper/optim.py b/egs/wenetspeech/ASR/whisper/optim.py
deleted file mode 100644
index 714d8db9a..000000000
--- a/egs/wenetspeech/ASR/whisper/optim.py
+++ /dev/null
@@ -1,1248 +0,0 @@
-# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
-#
-# See ../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import contextlib
-import logging
-import random
-from collections import defaultdict
-from typing import Dict, List, Optional, Tuple, Union
-
-import torch
-from lhotse.utils import fix_random_seed
-from torch import Tensor, nn
-from torch.optim import Optimizer
-
-
-class BatchedOptimizer(Optimizer):
- """
- This class adds to class Optimizer the capability to optimize parameters in batches:
- it will stack the parameters and their grads for you so the optimizer can work
- on tensors with an extra leading dimension. This is intended for speed with GPUs,
- as it reduces the number of kernels launched in the optimizer.
-
- Args:
- params:
- """
-
- def __init__(self, params, defaults):
- super(BatchedOptimizer, self).__init__(params, defaults)
-
- @contextlib.contextmanager
- def batched_params(self, param_group, group_params_names):
- """
- This function returns (technically, yields) a list of
- of tuples (p, state), where
- p is a `fake` parameter that is stacked (over axis 0) from real parameters
- that share the same shape, and its gradient is also stacked;
- `state` is the state corresponding to this batch of parameters
- (it will be physically located in the "state" for one of the real
- parameters, the last one that has any particular shape and dtype).
-
- This function is decorated as a context manager so that it can
- write parameters back to their "real" locations.
-
- The idea is, instead of doing:
-
- for p in group["params"]:
- state = self.state[p]
- ...
-
- you can do:
-
- with self.batched_params(group["params"]) as batches:
- for p, state, p_names in batches:
- ...
-
-
- Args:
- group: a parameter group, which is a list of parameters; should be
- one of self.param_groups.
- group_params_names: name for each parameter in group,
- which is List[str].
- """
- batches = defaultdict(
- list
- ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
- batches_names = defaultdict(
- list
- ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
-
- assert len(param_group) == len(group_params_names)
- for p, named_p in zip(param_group, group_params_names):
- key = (str(p.dtype), *p.shape)
- batches[key].append(p)
- batches_names[key].append(named_p)
-
- batches_names_keys = list(batches_names.keys())
- sorted_idx = sorted(
- range(len(batches_names)), key=lambda i: batches_names_keys[i]
- )
- batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
- batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
-
- stacked_params_dict = dict()
-
- # turn batches into a list, in deterministic order.
- # tuples will contain tuples of (stacked_param, state, stacked_params_names),
- # one for each batch in `batches`.
- tuples = []
-
- for batch, batch_names in zip(batches, batches_names):
- p = batch[0]
- # we arbitrarily store the state in the
- # state corresponding to the 1st parameter in the
- # group. class Optimizer will take care of saving/loading state.
- state = self.state[p]
- p_stacked = torch.stack(batch)
- grad = torch.stack(
- [torch.zeros_like(p) if p.grad is None else p.grad for p in batch]
- )
- p_stacked.grad = grad
- stacked_params_dict[key] = p_stacked
- tuples.append((p_stacked, state, batch_names))
-
- yield tuples # <-- calling code will do the actual optimization here!
-
- for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
- for i, p in enumerate(batch): # batch is list of Parameter
- p.copy_(stacked_params[i])
-
-
-class ScaledAdam(BatchedOptimizer):
- """
- Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
- proportional to the norm of that parameter; and also learn the scale of the parameter,
- in log space, subject to upper and lower limits (as if we had factored each parameter as
- param = underlying_param * log_scale.exp())
-
-
- Args:
- params: The parameters or param_groups to optimize (like other Optimizer subclasses)
- Unlike common optimizers, which accept model.parameters() or groups of parameters(),
- this optimizer could accept model.named_parameters() or groups of named_parameters().
- See comments of function _get_names_of_parameters for its 4 possible cases.
- lr: The learning rate. We will typically use a learning rate schedule that starts
- at 0.03 and decreases over time, i.e. much higher than other common
- optimizers.
- clipping_scale: (e.g. 2.0)
- A scale for gradient-clipping: if specified, the normalized gradients
- over the whole model will be clipped to have 2-norm equal to
- `clipping_scale` times the median 2-norm over the most recent period
- of `clipping_update_period` minibatches. By "normalized gradients",
- we mean after multiplying by the rms parameter value for this tensor
- [for non-scalars]; this is appropriate because our update is scaled
- by this quantity.
- betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
- Must satisfy 0 < beta <= beta2 < 1.
- scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
- scale of each parameter tensor and scalar parameters of the mode..
- If each parameter were decomposed
- as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
- would be a the scaling factor on the learning rate of p_scale.
- eps: A general-purpose epsilon to prevent division by zero
- param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
- learning the scale on the parameters (we'll constrain the rms of each non-scalar
- parameter tensor to be >= this value)
- param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
- learning the scale on the parameters (we'll constrain the rms of each non-scalar
- parameter tensor to be <= this value)
- scalar_max: Maximum absolute value for scalar parameters (applicable if your
- model has any parameters with numel() == 1).
- size_update_period: The periodicity, in steps, with which we update the size (scale)
- of the parameter tensor. This is provided to save a little time
- in the update.
- clipping_update_period: if clipping_scale is specified, this is the period
- """
-
- def __init__(
- self,
- params,
- lr=3e-02,
- clipping_scale=None,
- betas=(0.9, 0.98),
- scalar_lr_scale=0.1,
- eps=1.0e-08,
- param_min_rms=1.0e-05,
- param_max_rms=3.0,
- scalar_max=10.0,
- size_update_period=4,
- clipping_update_period=100,
- ):
-
- defaults = dict(
- lr=lr,
- clipping_scale=clipping_scale,
- betas=betas,
- scalar_lr_scale=scalar_lr_scale,
- eps=eps,
- param_min_rms=param_min_rms,
- param_max_rms=param_max_rms,
- scalar_max=scalar_max,
- size_update_period=size_update_period,
- clipping_update_period=clipping_update_period,
- )
-
- # If params only contains parameters or group of parameters,
- # i.e when parameter names are not given,
- # this flag will be set to False in funciton _get_names_of_parameters.
- self.show_dominant_parameters = True
- param_groups, parameters_names = self._get_names_of_parameters(params)
- super(ScaledAdam, self).__init__(param_groups, defaults)
- assert len(self.param_groups) == len(parameters_names)
- self.parameters_names = parameters_names
-
- def _get_names_of_parameters(
- self, params_or_named_params
- ) -> Tuple[List[Dict], List[List[str]]]:
- """
- Args:
- params_or_named_params: according to the way ScaledAdam is initialized in train.py,
- this argument could be one of following 4 cases,
- case 1, a generator of parameter, e.g.:
- optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0)
-
- case 2, a list of parameter groups with different config, e.g.:
- model_param_groups = [
- {'params': model.encoder.parameters(), 'lr': 0.05},
- {'params': model.decoder.parameters(), 'lr': 0.01},
- {'params': model.joiner.parameters(), 'lr': 0.03},
- ]
- optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0)
-
- case 3, a generator of named_parameter, e.g.:
- optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0)
-
- case 4, a list of named_parameter groups with different config, e.g.:
- model_named_param_groups = [
- {'named_params': model.encoder.named_parameters(), 'lr': 0.05},
- {'named_params': model.decoder.named_parameters(), 'lr': 0.01},
- {'named_params': model.joiner.named_parameters(), 'lr': 0.03},
- ]
- optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0)
-
- For case 1 and case 2, input params is used to initialize the underlying torch.optimizer.
- For case 3 and case 4, firstly, names and params are extracted from input named_params,
- then, these extracted params are used to initialize the underlying torch.optimizer,
- and these extracted names are mainly used by function
- `_show_gradient_dominating_parameter`
-
- Returns:
- Returns a tuple containing 2 elements:
- - `param_groups` with type List[Dict], each Dict element is a parameter group.
- An example of `param_groups` could be:
- [
- {'params': `one iterable of Parameter`, 'lr': 0.05},
- {'params': `another iterable of Parameter`, 'lr': 0.08},
- {'params': `a third iterable of Parameter`, 'lr': 0.1},
- ]
- - `param_gruops_names` with type List[List[str]],
- each `List[str]` is for a group['params'] in param_groups,
- and each `str` is the name of a parameter.
- A dummy name "foo" is related to each parameter,
- if input are params without names, i.e. case 1 or case 2.
- """
- # variable naming convention in this function:
- # p is short for param.
- # np is short for named_param.
- # p_or_np is short for param_or_named_param.
- # cur is short for current.
- # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}.
- # groups is a List[group]
-
- iterable_or_groups = list(params_or_named_params)
- if len(iterable_or_groups) == 0:
- raise ValueError("optimizer got an empty parameter list")
-
- # The first value of returned tuple. A list of dicts containing at
- # least 'params' as a key.
- param_groups = []
-
- # The second value of returned tuple,
- # a List[List[str]], each sub-List is for a group.
- param_groups_names = []
-
- if not isinstance(iterable_or_groups[0], dict):
- # case 1 or case 3,
- # the input is an iterable of parameter or named parameter.
- param_iterable_cur_group = []
- param_names_cur_group = []
- for p_or_np in iterable_or_groups:
- if isinstance(p_or_np, tuple):
- # case 3
- name, param = p_or_np
- else:
- # case 1
- assert isinstance(p_or_np, torch.Tensor)
- param = p_or_np
- # Assign a dummy name as a placeholder
- name = "foo"
- self.show_dominant_parameters = False
- param_iterable_cur_group.append(param)
- param_names_cur_group.append(name)
- param_groups.append({"params": param_iterable_cur_group})
- param_groups_names.append(param_names_cur_group)
- else:
- # case 2 or case 4
- # the input is groups of parameter or named parameter.
- for cur_group in iterable_or_groups:
- assert "named_params" in cur_group
- name_list = [x[0] for x in cur_group["named_params"]]
- p_list = [x[1] for x in cur_group["named_params"]]
- del cur_group["named_params"]
- cur_group["params"] = p_list
- param_groups.append(cur_group)
- param_groups_names.append(name_list)
-
- return param_groups, param_groups_names
-
- def __setstate__(self, state):
- super(ScaledAdam, self).__setstate__(state)
-
- @torch.no_grad()
- def step(self, closure=None):
- """Performs a single optimization step.
-
- Arguments:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
- """
- loss = None
- if closure is not None:
- with torch.enable_grad():
- loss = closure()
-
- batch = True
-
- for group, group_params_names in zip(self.param_groups, self.parameters_names):
-
- with self.batched_params(group["params"], group_params_names) as batches:
-
- # batches is list of pairs (stacked_param, state). stacked_param is like
- # a regular parameter, and will have a .grad, but the 1st dim corresponds to
- # a stacking dim, it is not a real dim.
-
- if (
- len(batches[0][1]) == 0
- ): # if len(first state) == 0: not yet initialized
- clipping_scale = 1
- else:
- clipping_scale = self._get_clipping_scale(group, batches)
-
- for p, state, _ in batches:
- # Perform optimization step.
- # grad is not going to be None, we handled that when creating the batches.
- grad = p.grad
- if grad.is_sparse:
- raise RuntimeError(
- "ScaledAdam optimizer does not support sparse gradients"
- )
- # State initialization
- if len(state) == 0:
- self._init_state(group, p, state)
-
- self._step_one_batch(group, p, state, clipping_scale)
-
- return loss
-
- def _init_state(self, group: dict, p: Tensor, state: dict):
- """
- Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
- is actually the batch dimension, corresponding to batched-together
- parameters of a given shape.
-
-
- Args:
- group: Dict to look up configuration values.
- p: The parameter that we are initializing the state for
- state: Dict from string to whatever state we are initializing
- """
- size_update_period = group["size_update_period"]
-
- state["step"] = 0
-
- kwargs = {"device": p.device, "dtype": p.dtype}
-
- # 'delta' implements conventional momentum. There are
- # several different kinds of update going on, so rather than
- # compute "exp_avg" like in Adam, we store and decay a
- # parameter-change "delta", which combines all forms of
- # update. this is equivalent to how it's done in Adam,
- # except for the first few steps.
- state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
-
- batch_size = p.shape[0]
- numel = p.numel() // batch_size
-
- if numel > 1:
- # "param_rms" just periodically records the scalar root-mean-square value of
- # the parameter tensor.
- # it has a shape like (batch_size, 1, 1, 1, 1)
- param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
- state["param_rms"] = param_rms
-
- state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
- state["scale_grads"] = torch.zeros(
- size_update_period, *param_rms.shape, **kwargs
- )
-
- # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
- state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
-
- def _get_clipping_scale(
- self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
- ) -> float:
- """
- Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
- by this amount before applying the rest of the update.
-
- Args:
- group: the parameter group, an item in self.param_groups
- tuples: a list of tuples of (param, state, param_names)
- where param is a batched set of parameters,
- with a .grad (1st dim is batch dim)
- and state is the state-dict where optimization parameters are kept.
- param_names is a List[str] while each str is name for a parameter
- in batched set of parameters "param".
- """
- assert len(tuples) >= 1
- clipping_scale = group["clipping_scale"]
- (first_p, first_state, _) = tuples[0]
- step = first_state["step"]
- if clipping_scale is None or step == 0:
- # no clipping. return early on step == 0 because the other
- # parameters' state won't have been initialized yet.
- return 1.0
- clipping_update_period = group["clipping_update_period"]
- scalar_lr_scale = group["scalar_lr_scale"]
-
- tot_sumsq = torch.tensor(0.0, device=first_p.device)
- for (p, state, param_names) in tuples:
- grad = p.grad
- if grad.is_sparse:
- raise RuntimeError(
- "ScaledAdam optimizer does not support sparse gradients"
- )
- if p.numel() == p.shape[0]: # a batch of scalars
- tot_sumsq += (grad**2).sum() * (
- scalar_lr_scale**2
- ) # sum() to change shape [1] to []
- else:
- tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
-
- tot_norm = tot_sumsq.sqrt()
- if "model_norms" not in first_state:
- first_state["model_norms"] = torch.zeros(
- clipping_update_period, device=p.device
- )
- first_state["model_norms"][step % clipping_update_period] = tot_norm
-
- irregular_estimate_steps = [
- i for i in [10, 20, 40] if i < clipping_update_period
- ]
- if step % clipping_update_period == 0 or step in irregular_estimate_steps:
- # Print some stats.
- # We don't reach here if step == 0 because we would have returned
- # above.
- sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
- if step in irregular_estimate_steps:
- sorted_norms = sorted_norms[-step:]
- num_norms = sorted_norms.numel()
- quartiles = []
- for n in range(0, 5):
- index = min(num_norms - 1, (num_norms // 4) * n)
- quartiles.append(sorted_norms[index].item())
-
- median = quartiles[2]
- if median - median != 0:
- raise RuntimeError("Too many grads were not finite")
- threshold = clipping_scale * median
- if step in irregular_estimate_steps:
- # use larger thresholds on first few steps of estimating threshold,
- # as norm may be changing rapidly.
- threshold = threshold * 2.0
- first_state["model_norm_threshold"] = threshold
- percent_clipped = (
- first_state["num_clipped"] * 100.0 / num_norms
- if "num_clipped" in first_state
- else 0.0
- )
- first_state["num_clipped"] = 0
- quartiles = " ".join(["%.3e" % x for x in quartiles])
- logging.warn(
- f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
- f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
- )
-
- try:
- model_norm_threshold = first_state["model_norm_threshold"]
- except KeyError:
- return 1.0 # threshold has not yet been set.
-
- ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
- if ans != ans: # e.g. ans is nan
- ans = 0.0
- if ans < 1.0:
- first_state["num_clipped"] += 1
- if ans < 0.1:
- logging.warn(
- f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
- )
- if self.show_dominant_parameters:
- assert p.shape[0] == len(param_names)
- self._show_gradient_dominating_parameter(
- tuples, tot_sumsq, group["scalar_lr_scale"]
- )
-
- if ans == 0.0:
- for (p, state, param_names) in tuples:
- p.grad.zero_() # get rid of infinity()
-
- return ans
-
- def _show_gradient_dominating_parameter(
- self,
- tuples: List[Tuple[Tensor, dict, List[str]]],
- tot_sumsq: Tensor,
- scalar_lr_scale: float,
- ):
- """
- Show information of parameter which dominates tot_sumsq.
-
- Args:
- tuples: a list of tuples of (param, state, param_names)
- where param is a batched set of parameters,
- with a .grad (1st dim is batch dim)
- and state is the state-dict where optimization parameters are kept.
- param_names is a List[str] while each str is name for a parameter
- in batched set of parameters "param".
- tot_sumsq: sumsq of all parameters. Though it's could be calculated
- from tuples, we still pass it to save some time.
- """
- all_sumsq_orig = {}
- for (p, state, batch_param_names) in tuples:
- # p is a stacked batch parameters.
- batch_grad = p.grad
- if p.numel() == p.shape[0]: # a batch of scalars
- # Dummy values used by following `zip` statement.
- batch_rms_orig = torch.full(
- p.shape, scalar_lr_scale, device=batch_grad.device
- )
- else:
- batch_rms_orig = state["param_rms"]
- batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2
- if batch_grad.ndim > 1:
- # need to guard it with if-statement because sum() sums over
- # all dims if dim == ().
- batch_sumsq_orig = batch_sumsq_orig.sum(
- dim=list(range(1, batch_grad.ndim))
- )
- for name, sumsq_orig, rms, grad in zip(
- batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
- ):
-
- proportion_orig = sumsq_orig / tot_sumsq
- all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
-
- sorted_by_proportion = {
- k: v
- for k, v in sorted(
- all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True
- )
- }
- dominant_param_name = next(iter(sorted_by_proportion))
- (
- dominant_proportion,
- dominant_sumsq,
- dominant_rms,
- dominant_grad,
- ) = sorted_by_proportion[dominant_param_name]
- logging.warn(
- f"Parameter dominating tot_sumsq {dominant_param_name}"
- f" with proportion {dominant_proportion:.2f},"
- f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
- f"={dominant_sumsq:.3e},"
- f" grad_sumsq={(dominant_grad**2).sum():.3e},"
- f" orig_rms_sq={(dominant_rms**2).item():.3e}"
- )
-
- def _step_one_batch(
- self, group: dict, p: Tensor, state: dict, clipping_scale: float
- ):
- """
- Do the step for one parameter, which is actually going to be a batch of
- `real` parameters, with dim 0 as the batch dim.
- Args:
- group: dict to look up configuration values
- p: parameter to update (actually multiple parameters stacked together
- as a batch)
- state: state-dict for p, to look up the optimizer state
- """
- lr = group["lr"]
- size_update_period = group["size_update_period"]
- beta1 = group["betas"][0]
-
- grad = p.grad
- if clipping_scale != 1.0:
- grad *= clipping_scale
- step = state["step"]
- delta = state["delta"]
-
- delta.mul_(beta1)
- batch_size = p.shape[0]
- numel = p.numel() // batch_size
- if numel > 1:
- # Update the size/scale of p, and set param_rms
- scale_grads = state["scale_grads"]
- scale_grads[step % size_update_period] = (p * grad).sum(
- dim=list(range(1, p.ndim)), keepdim=True
- )
- if step % size_update_period == size_update_period - 1:
- param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
- param_rms.copy_(
- (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
- )
- if step > 0:
- # self._size_update() learns the overall scale on the
- # parameter, by shrinking or expanding it.
- self._size_update(group, scale_grads, p, state)
-
- if numel == 1:
- # For parameters with 1 element we just use regular Adam.
- # Updates delta.
- self._step_scalar(group, p, state)
- else:
- self._step(group, p, state)
-
- state["step"] = step + 1
-
- def _size_update(
- self, group: dict, scale_grads: Tensor, p: Tensor, state: dict
- ) -> None:
- """
- Called only where p.numel() > 1, this updates the scale of the parameter.
- If we imagine: p = underlying_param * scale.exp(), and we are doing
- gradient descent on underlying param and on scale, this function does the update
- on `scale`.
-
- Args:
- group: dict to look up configuration values
- scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
- grads w.r.t. the scales.
- p: The parameter to update
- state: The state-dict of p
- """
-
- param_rms = state["param_rms"]
- beta1, beta2 = group["betas"]
- size_lr = group["lr"] * group["scalar_lr_scale"]
- param_min_rms = group["param_min_rms"]
- param_max_rms = group["param_max_rms"]
- eps = group["eps"]
- step = state["step"]
- batch_size = p.shape[0]
-
- size_update_period = scale_grads.shape[0]
- # correct beta2 for the size update period: we will have
- # faster decay at this level.
- beta2_corr = beta2**size_update_period
-
- scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
- scale_exp_avg_sq.mul_(beta2_corr).add_(
- (scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
- alpha=1 - beta2_corr,
- ) # shape is (batch_size, 1, 1, ...)
-
- # The 1st time we reach here is when size_step == 1.
- size_step = (step + 1) // size_update_period
- bias_correction2 = 1 - beta2_corr**size_step
- # we don't bother with bias_correction1; this will help prevent divergence
- # at the start of training.
-
- denom = scale_exp_avg_sq.sqrt() + eps
-
- scale_step = (
- -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
- )
-
- is_too_small = param_rms < param_min_rms
-
- # when the param gets too small, just don't shrink it any further.
- scale_step.masked_fill_(is_too_small, 0.0)
-
- # and ensure the parameter rms after update never exceeds param_max_rms.
- # We have to look at the trained model for parameters at or around the
- # param_max_rms, because sometimes they can indicate a problem with the
- # topology or settings.
- scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
-
- delta = state["delta"]
- # the factor of (1-beta1) relates to momentum.
- delta.add_(p * scale_step, alpha=(1 - beta1))
-
- def _step(self, group: dict, p: Tensor, state: dict):
- """
- This function does the core update of self.step(), in the case where the members of
- the batch have more than 1 element.
-
- Args:
- group: A dict which will be used to look up configuration values
- p: The parameter to be updated
- grad: The grad of p
- state: The state-dict corresponding to parameter p
-
- This function modifies p.
- """
- grad = p.grad
- lr = group["lr"]
- beta1, beta2 = group["betas"]
- eps = group["eps"]
- param_min_rms = group["param_min_rms"]
- step = state["step"]
-
- exp_avg_sq = state["exp_avg_sq"]
- exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
-
- this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
- bias_correction2 = 1 - beta2 ** (this_step + 1)
- if bias_correction2 < 0.99:
- # note: not in-place.
- exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
-
- denom = exp_avg_sq.sqrt()
- denom += eps
- grad = grad / denom
-
- alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
-
- delta = state["delta"]
- delta.add_(grad * alpha)
- p.add_(delta)
-
- def _step_scalar(self, group: dict, p: Tensor, state: dict):
- """
- A simplified form of the core update for scalar tensors, where we cannot get a good
- estimate of the parameter rms.
- """
- beta1, beta2 = group["betas"]
- scalar_max = group["scalar_max"]
- eps = group["eps"]
- lr = group["lr"] * group["scalar_lr_scale"]
- grad = p.grad
-
- exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
- exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
-
- # bias_correction2 is like in Adam. Don't bother with bias_correction1;
- # slower update at the start will help stability anyway.
- bias_correction2 = 1 - beta2 ** (state["step"] + 1)
- denom = (exp_avg_sq / bias_correction2).sqrt() + eps
-
- delta = state["delta"]
- delta.add_(grad / denom, alpha=-lr * (1 - beta1))
- p.clamp_(min=-scalar_max, max=scalar_max)
- p.add_(delta)
-
-
-class LRScheduler(object):
- """
- Base-class for learning rate schedulers where the learning-rate depends on both the
- batch and the epoch.
- """
-
- def __init__(self, optimizer: Optimizer, verbose: bool = False):
- # Attach optimizer
- if not isinstance(optimizer, Optimizer):
- raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
- self.optimizer = optimizer
- self.verbose = verbose
-
- for group in optimizer.param_groups:
- group.setdefault("base_lr", group["lr"])
-
- self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]
-
- self.epoch = 0
- self.batch = 0
-
- def state_dict(self):
- """Returns the state of the scheduler as a :class:`dict`.
-
- It contains an entry for every variable in self.__dict__ which
- is not the optimizer.
- """
- return {
- "base_lrs": self.base_lrs,
- "epoch": self.epoch,
- "batch": self.batch,
- }
-
- def load_state_dict(self, state_dict):
- """Loads the schedulers state.
-
- Args:
- state_dict (dict): scheduler state. Should be an object returned
- from a call to :meth:`state_dict`.
- """
- self.__dict__.update(state_dict)
-
- def get_last_lr(self) -> List[float]:
- """Return last computed learning rate by current scheduler. Will be a list of float."""
- return self._last_lr
-
- def get_lr(self):
- # Compute list of learning rates from self.epoch and self.batch and
- # self.base_lrs; this must be overloaded by the user.
- # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
- raise NotImplementedError
-
- def step_batch(self, batch: Optional[int] = None) -> None:
- # Step the batch index, or just set it. If `batch` is specified, it
- # must be the batch index from the start of training, i.e. summed over
- # all epochs.
- # You can call this in any order; if you don't provide 'batch', it should
- # of course be called once per batch.
- if batch is not None:
- self.batch = batch
- else:
- self.batch = self.batch + 1
- self._set_lrs()
-
- def step_epoch(self, epoch: Optional[int] = None):
- # Step the epoch index, or just set it. If you provide the 'epoch' arg,
- # you should call this at the start of the epoch; if you don't provide the 'epoch'
- # arg, you should call it at the end of the epoch.
- if epoch is not None:
- self.epoch = epoch
- else:
- self.epoch = self.epoch + 1
- self._set_lrs()
-
- def _set_lrs(self):
- values = self.get_lr()
- assert len(values) == len(self.optimizer.param_groups)
-
- for i, data in enumerate(zip(self.optimizer.param_groups, values)):
- param_group, lr = data
- param_group["lr"] = lr
- self.print_lr(self.verbose, i, lr)
- self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
-
- def print_lr(self, is_verbose, group, lr):
- """Display the current learning rate."""
- if is_verbose:
- logging.warn(
- f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
- f" of group {group} to {lr:.4e}."
- )
-
-
-class Eden(LRScheduler):
- """
- Eden scheduler.
- The basic formula (before warmup) is:
- lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
- (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
- where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches
- and then stays constant at 1.
-
- If you don't have the concept of epochs, or one epoch takes a very long time,
- you can replace the notion of 'epoch' with some measure of the amount of data
- processed, e.g. hours of data or frames of data, with 'lr_epochs' being set to
- some measure representing "quite a lot of data": say, one fifth or one third
- of an entire training run, but it doesn't matter much. You could also use
- Eden2 which has only the notion of batches.
-
- We suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
-
- Args:
- optimizer: the optimizer to change the learning rates on
- lr_batches: the number of batches after which we start significantly
- decreasing the learning rate, suggest 5000.
- lr_epochs: the number of epochs after which we start significantly
- decreasing the learning rate, suggest 6 if you plan to do e.g.
- 20 to 40 epochs, but may need smaller number if dataset is huge
- and you will do few epochs.
- """
-
- def __init__(
- self,
- optimizer: Optimizer,
- lr_batches: Union[int, float],
- lr_epochs: Union[int, float],
- warmup_batches: Union[int, float] = 500.0,
- warmup_start: float = 0.5,
- verbose: bool = False,
- ):
- super(Eden, self).__init__(optimizer, verbose)
- self.lr_batches = lr_batches
- self.lr_epochs = lr_epochs
- self.warmup_batches = warmup_batches
-
- assert 0.0 <= warmup_start <= 1.0, warmup_start
- self.warmup_start = warmup_start
-
- def get_lr(self):
- factor = (
- (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
- ) ** -0.25 * (
- ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25
- )
- warmup_factor = (
- 1.0
- if self.batch >= self.warmup_batches
- else self.warmup_start
- + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
- # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
- )
-
- return [x * factor * warmup_factor for x in self.base_lrs]
-
-
-class Eden2(LRScheduler):
- """
- Eden2 scheduler, simpler than Eden because it does not use the notion of epoch,
- only batches.
-
- The basic formula (before warmup) is:
- lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5) * warmup
-
- where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches
- and then stays constant at 1.
-
-
- E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
-
- Args:
- optimizer: the optimizer to change the learning rates on
- lr_batches: the number of batches after which we start significantly
- decreasing the learning rate, suggest 5000.
- """
-
- def __init__(
- self,
- optimizer: Optimizer,
- lr_batches: Union[int, float],
- warmup_batches: Union[int, float] = 500.0,
- warmup_start: float = 0.5,
- verbose: bool = False,
- ):
- super().__init__(optimizer, verbose)
- self.lr_batches = lr_batches
- self.warmup_batches = warmup_batches
-
- assert 0.0 <= warmup_start <= 1.0, warmup_start
- self.warmup_start = warmup_start
-
- def get_lr(self):
- factor = (
- (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
- ) ** -0.5
- warmup_factor = (
- 1.0
- if self.batch >= self.warmup_batches
- else self.warmup_start
- + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
- # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
- )
-
- return [x * factor * warmup_factor for x in self.base_lrs]
-
-
-def _test_eden():
- m = torch.nn.Linear(100, 100)
- optim = ScaledAdam(m.parameters(), lr=0.03)
-
- scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)
-
- for epoch in range(10):
- scheduler.step_epoch(epoch) # sets epoch to `epoch`
-
- for step in range(20):
- x = torch.randn(200, 100).detach()
- x.requires_grad = True
- y = m(x)
- dy = torch.randn(200, 100).detach()
- f = (y * dy).sum()
- f.backward()
-
- optim.step()
- scheduler.step_batch()
- optim.zero_grad()
-
- logging.info(f"last lr = {scheduler.get_last_lr()}")
- logging.info(f"state dict = {scheduler.state_dict()}")
-
-
-# This is included mostly as a baseline for ScaledAdam.
-class Eve(Optimizer):
- """
- Implements Eve algorithm. This is a modified version of AdamW with a special
- way of setting the weight-decay / shrinkage-factor, which is designed to make the
- rms of the parameters approach a particular target_rms (default: 0.1). This is
- for use with networks with 'scaled' versions of modules (see scaling.py), which
- will be close to invariant to the absolute scale on the parameter matrix.
-
- The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
- The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
- Eve is unpublished so far.
-
- Arguments:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its square (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-8)
- weight_decay (float, optional): weight decay coefficient (default: 3e-4;
- this value means that the weight would decay significantly after
- about 3k minibatches. Is not multiplied by learning rate, but
- is conditional on RMS-value of parameter being > target_rms.
- target_rms (float, optional): target root-mean-square value of
- parameters, if they fall below this we will stop applying weight decay.
-
-
- .. _Adam: A Method for Stochastic Optimization:
- https://arxiv.org/abs/1412.6980
- .. _Decoupled Weight Decay Regularization:
- https://arxiv.org/abs/1711.05101
- .. _On the Convergence of Adam and Beyond:
- https://openreview.net/forum?id=ryQu7f-RZ
- """
-
- def __init__(
- self,
- params,
- lr=1e-3,
- betas=(0.9, 0.98),
- eps=1e-8,
- weight_decay=1e-3,
- target_rms=0.1,
- ):
- if not 0.0 <= lr:
- raise ValueError("Invalid learning rate: {}".format(lr))
- if not 0.0 <= eps:
- raise ValueError("Invalid epsilon value: {}".format(eps))
- if not 0.0 <= betas[0] < 1.0:
- raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
- if not 0.0 <= betas[1] < 1.0:
- raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
- if not 0 <= weight_decay <= 0.1:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
- if not 0 < target_rms <= 10.0:
- raise ValueError("Invalid target_rms value: {}".format(target_rms))
- defaults = dict(
- lr=lr,
- betas=betas,
- eps=eps,
- weight_decay=weight_decay,
- target_rms=target_rms,
- )
- super(Eve, self).__init__(params, defaults)
-
- def __setstate__(self, state):
- super(Eve, self).__setstate__(state)
-
- @torch.no_grad()
- def step(self, closure=None):
- """Performs a single optimization step.
-
- Arguments:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
- """
- loss = None
- if closure is not None:
- with torch.enable_grad():
- loss = closure()
-
- for group in self.param_groups:
- for p in group["params"]:
- if p.grad is None:
- continue
-
- # Perform optimization step
- grad = p.grad
- if grad.is_sparse:
- raise RuntimeError("AdamW does not support sparse gradients")
-
- state = self.state[p]
-
- # State initialization
- if len(state) == 0:
- state["step"] = 0
- # Exponential moving average of gradient values
- state["exp_avg"] = torch.zeros_like(
- p, memory_format=torch.preserve_format
- )
- # Exponential moving average of squared gradient values
- state["exp_avg_sq"] = torch.zeros_like(
- p, memory_format=torch.preserve_format
- )
-
- exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-
- beta1, beta2 = group["betas"]
-
- state["step"] += 1
- bias_correction1 = 1 - beta1 ** state["step"]
- bias_correction2 = 1 - beta2 ** state["step"]
-
- # Decay the first and second moment running average coefficient
- exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
- exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
- denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_(
- group["eps"]
- )
-
- step_size = group["lr"] / bias_correction1
- target_rms = group["target_rms"]
- weight_decay = group["weight_decay"]
-
- if p.numel() > 1:
- # avoid applying this weight-decay on "scaling factors"
- # (which are scalar).
- is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5))
- p.mul_(1 - (weight_decay * is_above_target_rms))
-
- p.addcdiv_(exp_avg, denom, value=-step_size)
-
- if random.random() < 0.0005:
- step = (exp_avg / denom) * step_size
- logging.info(
- f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}"
- )
-
- return loss
-
-
-def _test_scaled_adam(hidden_dim: int):
- import timeit
-
- from scaling import ScaledLinear
-
- E = 100
- B = 4
- T = 2
- logging.info("in test_eve_cain")
- # device = torch.device('cuda')
- device = torch.device("cpu")
- dtype = torch.float32
-
- fix_random_seed(42)
- # these input_magnitudes and output_magnitudes are to test that
- # Abel is working as we expect and is able to adjust scales of
- # different dims differently.
- input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
- output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
-
- for iter in [1, 0]:
- fix_random_seed(42)
- Linear = torch.nn.Linear if iter == 0 else ScaledLinear
-
- m = torch.nn.Sequential(
- Linear(E, hidden_dim),
- torch.nn.PReLU(),
- Linear(hidden_dim, hidden_dim),
- torch.nn.PReLU(),
- Linear(hidden_dim, E),
- ).to(device)
-
- train_pairs = [
- (
- 100.0
- * torch.randn(B, T, E, device=device, dtype=dtype)
- * input_magnitudes,
- torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes,
- )
- for _ in range(20)
- ]
-
- if iter == 0:
- optim = Eve(m.parameters(), lr=0.003)
- elif iter == 1:
- optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
- scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
-
- start = timeit.default_timer()
- avg_loss = 0.0
- for epoch in range(180):
- scheduler.step_epoch()
- # if epoch == 100 and iter in [2,3]:
- # optim.reset_speedup() # check it doesn't crash.
-
- # if epoch == 130:
- # opts = diagnostics.TensorDiagnosticOptions(
- # 512
- # ) # allow 4 megabytes per sub-module
- # diagnostic = diagnostics.attach_diagnostics(m, opts)
-
- for n, (x, y) in enumerate(train_pairs):
- y_out = m(x)
- loss = ((y_out - y) ** 2).mean() * 100.0
- if epoch == 0 and n == 0:
- avg_loss = loss.item()
- else:
- avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
- if n == 0 and epoch % 5 == 0:
- # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
- # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
- # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
- # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
- # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
- # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
- # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
- # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
- lr = scheduler.get_last_lr()[0]
- logging.info(
- f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}"
- ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
- loss.log().backward()
- optim.step()
- optim.zero_grad()
- scheduler.step_batch()
-
- # diagnostic.print_diagnostics()
-
- stop = timeit.default_timer()
- logging.info(f"Iter={iter}, Time taken: {stop - start}")
-
- logging.info(f"last lr = {scheduler.get_last_lr()}")
- # logging.info("state dict = ", scheduler.state_dict())
- # logging.info("optim state_dict = ", optim.state_dict())
- logging.info(f"input_magnitudes = {input_magnitudes}")
- logging.info(f"output_magnitudes = {output_magnitudes}")
-
-
-if __name__ == "__main__":
- torch.set_num_threads(1)
- torch.set_num_interop_threads(1)
- logging.getLogger().setLevel(logging.INFO)
- import subprocess
-
- s = subprocess.check_output(
- "git status -uno .; git log -1; git diff HEAD .", shell=True
- )
- logging.info(s)
- import sys
-
- if len(sys.argv) > 1:
- hidden_dim = int(sys.argv[1])
- else:
- hidden_dim = 200
-
- _test_scaled_adam(hidden_dim)
- _test_eden()
diff --git a/egs/wenetspeech/ASR/whisper/optim.py b/egs/wenetspeech/ASR/whisper/optim.py
new file mode 120000
index 000000000..5eaa3cffd
--- /dev/null
+++ b/egs/wenetspeech/ASR/whisper/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/whisper/requirements.txt b/egs/wenetspeech/ASR/whisper/requirements.txt
deleted file mode 100755
index 0708f2344..000000000
--- a/egs/wenetspeech/ASR/whisper/requirements.txt
+++ /dev/null
@@ -1,10 +0,0 @@
-k2
-kaldialign
-git+https://github.com/lhotse-speech/lhotse
-sentencepiece
-tensorboard
-librosa
-git+https://github.com/yuekaizhang/whisper.git
-zhconv
-WeTextProcessing
-deepspeed
diff --git a/egs/wenetspeech/ASR/whisper/requirements.txt b/egs/wenetspeech/ASR/whisper/requirements.txt
new file mode 120000
index 000000000..744bf8bb6
--- /dev/null
+++ b/egs/wenetspeech/ASR/whisper/requirements.txt
@@ -0,0 +1 @@
+../../../aishell/ASR/whisper/requirements.txt
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/whisper/whisper_encoder_forward_monkey_patch.py b/egs/wenetspeech/ASR/whisper/whisper_encoder_forward_monkey_patch.py
deleted file mode 100644
index 5bfbdce3b..000000000
--- a/egs/wenetspeech/ASR/whisper/whisper_encoder_forward_monkey_patch.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import torch
-import torch.nn.functional as F
-import whisper
-
-
-def forward(self, x: torch.Tensor):
- """
- x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
- the mel spectrogram of the audio
- """
- x = F.gelu(self.conv1(x))
- x = F.gelu(self.conv2(x))
- x = x.permute(0, 2, 1)
-
- x = (x + self.positional_embedding[: x.shape[1], :]).to(x.dtype)
-
- for block in self.blocks:
- x = block(x)
-
- x = self.ln_post(x)
- return x
-
-
-def replace_whisper_encoder_forward():
- """
- This function monkey patches the forward method of the whisper encoder.
- To be called before the model is loaded, it changes whisper to process audio with any length < 30s.
- """
- whisper.model.AudioEncoder.forward = forward
diff --git a/egs/wenetspeech/ASR/whisper/whisper_encoder_forward_monkey_patch.py b/egs/wenetspeech/ASR/whisper/whisper_encoder_forward_monkey_patch.py
new file mode 120000
index 000000000..2a7808921
--- /dev/null
+++ b/egs/wenetspeech/ASR/whisper/whisper_encoder_forward_monkey_patch.py
@@ -0,0 +1 @@
+../../../aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py
\ No newline at end of file