From fcd25bdfffb17b64a0c9d98250ae6021338e573f Mon Sep 17 00:00:00 2001 From: pkufool Date: Sun, 6 Feb 2022 18:22:56 +0800 Subject: [PATCH 001/234] Fix torch.nn.Embedding error for torch below 1.8.0 --- egs/librispeech/ASR/transducer/beam_search.py | 4 +++- egs/librispeech/ASR/transducer/model.py | 1 + egs/librispeech/ASR/transducer_lstm/beam_search.py | 4 +++- egs/librispeech/ASR/transducer_lstm/model.py | 1 + egs/librispeech/ASR/transducer_stateless/beam_search.py | 2 +- egs/librispeech/ASR/transducer_stateless/model.py | 1 + 6 files changed, 10 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py index f45d06ce9..11032f31a 100644 --- a/egs/librispeech/ASR/transducer/beam_search.py +++ b/egs/librispeech/ASR/transducer/beam_search.py @@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: blank_id = model.decoder.blank_id device = model.device - sos = torch.tensor([blank_id], device=device).reshape(1, 1) + sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape( + 1, 1 + ) decoder_out, (h, c) = model.decoder(sos) T = encoder_out.size(1) t = 0 diff --git a/egs/librispeech/ASR/transducer/model.py b/egs/librispeech/ASR/transducer/model.py index fa0b2dd68..8305248c9 100644 --- a/egs/librispeech/ASR/transducer/model.py +++ b/egs/librispeech/ASR/transducer/model.py @@ -99,6 +99,7 @@ class Transducer(nn.Module): sos_y = add_sos(y, sos_id=blank_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + sos_y_padded = sos_y_padded.to(torch.int64) decoder_out, _ = self.decoder(sos_y_padded) diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py index dfc22fcf8..3531a9633 100644 --- a/egs/librispeech/ASR/transducer_lstm/beam_search.py +++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py @@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: blank_id = model.decoder.blank_id device = model.device - sos = torch.tensor([blank_id], device=device).reshape(1, 1) + sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape( + 1, 1 + ) decoder_out, (h, c) = model.decoder(sos) T = encoder_out.size(1) t = 0 diff --git a/egs/librispeech/ASR/transducer_lstm/model.py b/egs/librispeech/ASR/transducer_lstm/model.py index cb9afd8a2..31843b60e 100644 --- a/egs/librispeech/ASR/transducer_lstm/model.py +++ b/egs/librispeech/ASR/transducer_lstm/model.py @@ -101,6 +101,7 @@ class Transducer(nn.Module): sos_y = add_sos(y, sos_id=sos_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + sos_y_padded = sos_y_padded.to(torch.int64) decoder_out, _ = self.decoder(sos_y_padded) diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index 341c74fab..1cce48235 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -48,7 +48,7 @@ def greedy_search( device = model.device decoder_input = torch.tensor( - [blank_id] * context_size, device=device + [blank_id] * context_size, device=device, dtype=torch.int64 ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index 7aac290d9..17b5f63e5 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -93,6 +93,7 @@ class Transducer(nn.Module): sos_y = add_sos(y, sos_id=blank_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + sos_y_padded = sos_y_padded.to(torch.int64) decoder_out = self.decoder(sos_y_padded) From 8f8ec223a715776f8e92a6daaa082627deae3cc8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 6 Feb 2022 21:18:40 +0800 Subject: [PATCH 002/234] Changes to fbank computation, use lilcom chunky writer --- egs/librispeech/ASR/local/compute_fbank_librispeech.py | 4 ++-- egs/librispeech/ASR/local/compute_fbank_musan.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index b26034eb2..5c33ff8be 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -28,7 +28,7 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer +from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor @@ -85,7 +85,7 @@ def compute_fbank_librispeech(): # when an executor is specified, make more partitions num_jobs=num_jobs if ex is None else 80, executor=ex, - storage_type=LilcomHdf5Writer, + storage_type=ChunkedLilcomHdf5Writer, ) cut_set.to_json(output_dir / f"cuts_{partition}.json.gz") diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index d44524e70..f5911746b 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -28,7 +28,7 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer, combine +from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig, combine from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor @@ -82,7 +82,7 @@ def compute_fbank_musan(): storage_path=f"{output_dir}/feats_musan", num_jobs=num_jobs if ex is None else 80, executor=ex, - storage_type=LilcomHdf5Writer, + storage_type=ChunkedLilcomHdf5Writer, ) ) musan_cuts.to_json(musan_cuts_path) From 48a764eccf30e0d7a178563255a648d292a19673 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 6 Feb 2022 21:19:37 +0800 Subject: [PATCH 003/234] Add min in q,k,v of attention --- .../ASR/transducer_stateless/conformer.py | 51 +++++++++++++++++-- .../ASR/transducer_stateless/decoder.py | 1 + 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 81d7708f9..f803ee9b6 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -440,8 +440,19 @@ class RelPositionMultiheadAttention(nn.Module): ), "embed_dim must be divisible by num_heads" self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + + self.in_proj_floor_scale = 10.0 # so it learns fast enough.. + with torch.no_grad(): + in_proj_floor = torch.Tensor(3 * embed_dim) + # key and query get a floor value quite close to zero. + in_proj_floor[:2*embed_dim] = -0.2 / self.in_proj_floor_scale + # value gets very low floor, may be close to having no effectc. + in_proj_floor[2*embed_dim:] = -1.5 / self.in_proj_floor_scale + self.in_proj_floor = nn.Parameter(in_proj_floor) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + # linear transformation for positional encoding. self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) # these two learnable bias are used in matrix c and matrix d @@ -526,6 +537,7 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, + in_proj_floor=self.in_proj_floor*self.in_proj_floor_scale ) def rel_shift(self, x: Tensor) -> Tensor: @@ -570,6 +582,7 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, + in_proj_floor: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: @@ -629,9 +642,12 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear( + _qkv = nn.functional.linear( query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) + ) + if in_proj_floor is not None: + _qkv = torch.maximum(_qkv, in_proj_floor) + q, k, v = _qkv.chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -643,6 +659,10 @@ class RelPositionMultiheadAttention(nn.Module): if _b is not None: _b = _b[_start:_end] q = nn.functional.linear(query, _w, _b) + if in_proj_floor is not None: + _f = in_proj_floor[_start:_end] + q = torch.maximum(q, _f) + # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias _start = embed_dim @@ -650,7 +670,11 @@ class RelPositionMultiheadAttention(nn.Module): _w = in_proj_weight[_start:, :] if _b is not None: _b = _b[_start:] - k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + _kv = nn.functional.linear(key, _w, _b) + if in_proj_floor is not None: + _f = in_proj_floor[_start:_end] + _kv = torch.maximum(_kv, _f) + k, v = _kv.chunk(2, dim=-1) else: # This is inline in_proj function with in_proj_weight and in_proj_bias @@ -661,6 +685,10 @@ class RelPositionMultiheadAttention(nn.Module): if _b is not None: _b = _b[_start:_end] q = nn.functional.linear(query, _w, _b) + if in_proj_floor is not None: + _f = in_proj_floor[_start:_end] + q = torch.maximum(q, _f) + # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias @@ -670,6 +698,9 @@ class RelPositionMultiheadAttention(nn.Module): if _b is not None: _b = _b[_start:_end] k = nn.functional.linear(key, _w, _b) + if in_proj_floor is not None: + _f = in_proj_floor[_start:_end] + k = torch.maximum(k, _f) # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias @@ -679,6 +710,10 @@ class RelPositionMultiheadAttention(nn.Module): if _b is not None: _b = _b[_start:] v = nn.functional.linear(value, _w, _b) + if in_proj_floor is not None: + _f = in_proj_floor[_start:_end] + v = torch.maximum(v, _f) + if attn_mask is not None: assert ( @@ -918,3 +953,13 @@ class Swish(torch.nn.Module): def identity(x): return x + + +if __name__ == '__main__': + feature_dim = 50 + c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c(torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64)) diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index c2c6552a9..003b03a2e 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -82,6 +82,7 @@ class Decoder(nn.Module): Returns: Return a tensor of shape (N, U, embedding_dim). """ + y = y.to(torch.int64) embedding_out = self.embedding(y) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) From a859dcb20504e7b4bbc2ea9b1f1b28ad5f5e0757 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 7 Feb 2022 12:14:48 +0800 Subject: [PATCH 004/234] Remove learnable offset, use relu instead. --- .../ASR/transducer_stateless/conformer.py | 46 +++---------------- 1 file changed, 6 insertions(+), 40 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index f803ee9b6..c06335905 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -440,19 +440,8 @@ class RelPositionMultiheadAttention(nn.Module): ), "embed_dim must be divisible by num_heads" self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) - - self.in_proj_floor_scale = 10.0 # so it learns fast enough.. - with torch.no_grad(): - in_proj_floor = torch.Tensor(3 * embed_dim) - # key and query get a floor value quite close to zero. - in_proj_floor[:2*embed_dim] = -0.2 / self.in_proj_floor_scale - # value gets very low floor, may be close to having no effectc. - in_proj_floor[2*embed_dim:] = -1.5 / self.in_proj_floor_scale - self.in_proj_floor = nn.Parameter(in_proj_floor) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) - # linear transformation for positional encoding. self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) # these two learnable bias are used in matrix c and matrix d @@ -537,7 +526,6 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask=key_padding_mask, need_weights=need_weights, attn_mask=attn_mask, - in_proj_floor=self.in_proj_floor*self.in_proj_floor_scale ) def rel_shift(self, x: Tensor) -> Tensor: @@ -582,7 +570,6 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, - in_proj_floor: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: @@ -642,12 +629,7 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - _qkv = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ) - if in_proj_floor is not None: - _qkv = torch.maximum(_qkv, in_proj_floor) - q, k, v = _qkv.chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).relu().chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -658,10 +640,7 @@ class RelPositionMultiheadAttention(nn.Module): _w = in_proj_weight[_start:_end, :] if _b is not None: _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - if in_proj_floor is not None: - _f = in_proj_floor[_start:_end] - q = torch.maximum(q, _f) + q = nn.functional.linear(query, _w, _b).relu() # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias @@ -670,11 +649,7 @@ class RelPositionMultiheadAttention(nn.Module): _w = in_proj_weight[_start:, :] if _b is not None: _b = _b[_start:] - _kv = nn.functional.linear(key, _w, _b) - if in_proj_floor is not None: - _f = in_proj_floor[_start:_end] - _kv = torch.maximum(_kv, _f) - k, v = _kv.chunk(2, dim=-1) + k, v = nn.functional.linear(key, _w, _b).relu().chunk(2, dim=-1) else: # This is inline in_proj function with in_proj_weight and in_proj_bias @@ -684,10 +659,7 @@ class RelPositionMultiheadAttention(nn.Module): _w = in_proj_weight[_start:_end, :] if _b is not None: _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - if in_proj_floor is not None: - _f = in_proj_floor[_start:_end] - q = torch.maximum(q, _f) + q = nn.functional.linear(query, _w, _b).relu() # This is inline in_proj function with in_proj_weight and in_proj_bias @@ -697,10 +669,7 @@ class RelPositionMultiheadAttention(nn.Module): _w = in_proj_weight[_start:_end, :] if _b is not None: _b = _b[_start:_end] - k = nn.functional.linear(key, _w, _b) - if in_proj_floor is not None: - _f = in_proj_floor[_start:_end] - k = torch.maximum(k, _f) + k = nn.functional.linear(key, _w, _b).relu() # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias @@ -709,10 +678,7 @@ class RelPositionMultiheadAttention(nn.Module): _w = in_proj_weight[_start:, :] if _b is not None: _b = _b[_start:] - v = nn.functional.linear(value, _w, _b) - if in_proj_floor is not None: - _f = in_proj_floor[_start:_end] - v = torch.maximum(v, _f) + v = nn.functional.linear(value, _w, _b).relu() if attn_mask is not None: From 3323cabf467324b5d8bc3b1247a37724cd778ed0 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Tue, 8 Feb 2022 14:25:31 +0800 Subject: [PATCH 005/234] Experiments based on SpecAugment change --- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 213 +++++++++++++++++- 1 file changed, 211 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index e075a2d03..e5fcc5893 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -28,7 +28,6 @@ from lhotse.dataset import ( K2SpeechRecognitionDataset, PrecomputedFeatures, SingleCutSampler, - SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures from torch.utils.data import DataLoader @@ -219,10 +218,11 @@ class LibriSpeechAsrDataModule: input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=2, + num_frame_masks=10, features_mask_size=27, num_feature_masks=2, frames_mask_size=100, + max_frames_mask_fraction=0.4, ) ) else: @@ -383,3 +383,212 @@ class LibriSpeechAsrDataModule: def test_other_cuts(self) -> CutSet: logging.info("About to get test-other cuts") return load_manifest(self.args.manifest_dir / "cuts_test-other.json.gz") + + +import math +import random +import numpy as np +from typing import Optional, Dict + +import torch + +from lhotse import CutSet + +class SpecAugment(torch.nn.Module): + """ + SpecAugment performs three augmentations: + - time warping of the feature matrix + - masking of ranges of features (frequency bands) + - masking of ranges of frames (time) + + The current implementation works with batches, but processes each example separately + in a loop rather than simultaneously to achieve different augmentation parameters for + each example. + """ + + def __init__( + self, + time_warp_factor: Optional[int] = 80, + num_feature_masks: int = 1, + features_mask_size: int = 13, + num_frame_masks: int = 1, + frames_mask_size: int = 70, + max_frames_mask_fraction: float = 0.2, + p=0.5, + ): + """ + SpecAugment's constructor. + + :param time_warp_factor: parameter for the time warping; larger values mean more warping. + Set to ``None``, or less than ``1``, to disable. + :param num_feature_masks: how many feature masks should be applied. Set to ``0`` to disable. + :param features_mask_size: the width of the feature mask (expressed in the number of masked feature bins). + This is the ``F`` parameter from the SpecAugment paper. + :param num_frame_masks: how many frame (temporal) masks should be applied. Set to ``0`` to disable. + :param frames_mask_size: the width of the frame (temporal) masks (expressed in the number of masked frames). + This is the ``T`` parameter from the SpecAugment paper. + :param max_frames_mask_fraction: limits the size of the frame (temporal) mask to this value times the length + of the utterance (or supervision segment). + This is the parameter denoted by ``p`` in the SpecAugment paper. + :param p: the probability of applying this transform. + It is different from ``p`` in the SpecAugment paper! + """ + super().__init__() + assert 0 <= p <= 1 + assert num_feature_masks >= 0 + assert num_frame_masks >= 0 + assert features_mask_size > 0 + assert frames_mask_size > 0 + self.time_warp_factor = time_warp_factor + self.num_feature_masks = num_feature_masks + self.features_mask_size = features_mask_size + self.num_frame_masks = num_frame_masks + self.frames_mask_size = frames_mask_size + self.max_frames_mask_fraction = max_frames_mask_fraction + self.p = p + + def forward( + self, + features: torch.Tensor, + supervision_segments: Optional[torch.IntTensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Computes SpecAugment for a batch of feature matrices. + + Since the batch will usually already be padded, the user can optionally + provide a ``supervision_segments`` tensor that will be used to apply SpecAugment + only to selected areas of the input. The format of this input is described below. + + :param features: a batch of feature matrices with shape ``(B, T, F)``. + :param supervision_segments: an int tensor of shape ``(S, 3)``. ``S`` is the number of + supervision segments that exist in ``features`` -- there may be either + less or more than the batch size. + The second dimension encoder three kinds of information: + the sequence index of the corresponding feature matrix in `features`, + the start frame index, and the number of frames for each segment. + :return: an augmented tensor of shape ``(B, T, F)``. + """ + assert len(features.shape) == 3, ( + "SpecAugment only supports batches of " "single-channel feature matrices." + ) + features = features.clone() + if supervision_segments is None: + # No supervisions - apply spec augment to full feature matrices. + for sequence_idx in range(features.size(0)): + features[sequence_idx] = self._forward_single(features[sequence_idx]) + else: + # Supervisions provided - we will apply time warping only on the supervised areas. + for sequence_idx, start_frame, num_frames in supervision_segments: + end_frame = start_frame + num_frames + features[sequence_idx, start_frame:end_frame] = self._forward_single( + features[sequence_idx, start_frame:end_frame], warp=True, mask=False + ) + # ... and then time-mask the full feature matrices. Note that in this mode, + # it might happen that masks are applied to different sequences/examples + # than the time warping. + for sequence_idx in range(features.size(0)): + features[sequence_idx] = self._forward_single( + features[sequence_idx], warp=False, mask=True + ) + return features + + def _forward_single( + self, features: torch.Tensor, warp: bool = True, mask: bool = True + ) -> torch.Tensor: + """ + Apply SpecAugment to a single feature matrix of shape (T, F). + """ + if random.random() > self.p: + # Randomly choose whether this transform is applied + return features + if warp: + if self.time_warp_factor is not None and self.time_warp_factor >= 1: + features = time_warp(features, factor=self.time_warp_factor) + if mask: + from torchaudio.functional import mask_along_axis + + mean = features.mean() + for _ in range(self.num_feature_masks): + features = mask_along_axis( + features.unsqueeze(0), + mask_param=self.features_mask_size, + mask_value=mean, + axis=2, + ).squeeze(0) + for _ in range(self.num_frame_masks): + _max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) + num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size)) + max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks) + + features = mask_along_axis( + features.unsqueeze(0), + mask_param=max_mask_frames, + mask_value=mean, + axis=1, + ).squeeze(0) + return features + + def state_dict(self) -> Dict: + return dict( + time_warp_factor=self.time_warp_factor, + num_feature_masks=self.num_feature_masks, + features_mask_size=self.features_mask_size, + num_frame_masks=self.num_frame_masks, + frames_mask_size=self.frames_mask_size, + max_frames_mask_fraction=self.max_frames_mask_fraction, + p=self.p, + ) + + def load_state_dict(self, state_dict: Dict): + self.time_warp_factor = state_dict.get( + "time_warp_factor", self.time_warp_factor + ) + self.num_feature_masks = state_dict.get( + "num_feature_masks", self.num_feature_masks + ) + self.features_mask_size = state_dict.get( + "features_mask_size", self.features_mask_size + ) + self.num_frame_masks = state_dict.get("num_frame_masks", self.num_frame_masks) + self.frames_mask_size = state_dict.get( + "frames_mask_size", self.frames_mask_size + ) + self.max_frames_mask_fraction = state_dict.get( + "max_frames_mask_fraction", self.max_frames_mask_fraction + ) + self.p = state_dict.get("p", self.p) + + +def time_warp(features: torch.Tensor, factor: int) -> torch.Tensor: + """ + Time warping as described in the SpecAugment paper. + Implementation based on Espresso: + https://github.com/freewym/espresso/blob/master/espresso/tools/specaug_interpolate.py#L51 + + :param features: input tensor of shape ``(T, F)`` + :param factor: time warping parameter. + :return: a warped tensor of shape ``(T, F)`` + """ + t = features.size(0) + if t - factor <= factor + 1: + return features + center = np.random.randint(factor + 1, t - factor) + warped = np.random.randint(center - factor, center + factor + 1) + if warped == center: + return features + features = features.unsqueeze(0).unsqueeze(0) + left = torch.nn.functional.interpolate( + features[:, :, :center, :], + size=(warped, features.size(3)), + mode="bicubic", + align_corners=False, + ) + right = torch.nn.functional.interpolate( + features[:, :, center:, :], + size=(t - warped, features.size(3)), + mode="bicubic", + align_corners=False, + ) + return torch.cat((left, right), dim=2).squeeze(0).squeeze(0) \ No newline at end of file From beaf5bfbab85108f32751d5590fddc642437fdb7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 8 Feb 2022 19:42:23 +0800 Subject: [PATCH 006/234] Merge specaug change from Mingshuang. --- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 950a88a35..5c447bc4b 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -109,7 +109,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/exp", + default="transducer_stateless/exp-100h-relu-specaug", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From bd36216e8cbc40b194e02e0e8d5bb86a3e60edf2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 8 Feb 2022 21:55:20 +0800 Subject: [PATCH 007/234] Use much more aggressive SpecAug setup --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index e5fcc5893..a5ab012e3 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -220,7 +220,7 @@ class LibriSpeechAsrDataModule: time_warp_factor=self.args.spec_aug_time_warp_factor, num_frame_masks=10, features_mask_size=27, - num_feature_masks=2, + num_feature_masks=10, frames_mask_size=100, max_frames_mask_fraction=0.4, ) @@ -521,7 +521,7 @@ class SpecAugment(torch.nn.Module): _max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size)) max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks) - + features = mask_along_axis( features.unsqueeze(0), mask_param=max_mask_frames, @@ -591,4 +591,4 @@ def time_warp(features: torch.Tensor, factor: int) -> torch.Tensor: mode="bicubic", align_corners=False, ) - return torch.cat((left, right), dim=2).squeeze(0).squeeze(0) \ No newline at end of file + return torch.cat((left, right), dim=2).squeeze(0).squeeze(0) From dd19a6a2b13a7c452f5910fb4a0e123910540302 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 9 Feb 2022 12:02:19 +0800 Subject: [PATCH 008/234] Fix to num_feature_masks bug I introduced; reduce max_frames_mask_fraction 0.4->0.3 --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index a5ab012e3..11b07bd69 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -220,9 +220,9 @@ class LibriSpeechAsrDataModule: time_warp_factor=self.args.spec_aug_time_warp_factor, num_frame_masks=10, features_mask_size=27, - num_feature_masks=10, + num_feature_masks=2, frames_mask_size=100, - max_frames_mask_fraction=0.4, + max_frames_mask_fraction=0.3, ) ) else: From 8aa50df4f0c5b6d1edb2e850364c32fd3c666aab Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 9 Feb 2022 22:52:53 +0800 Subject: [PATCH 009/234] Change p=0.5->0.9, mask_fraction 0.3->0.2 --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 3 ++- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 11b07bd69..7df7a3525 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -222,7 +222,8 @@ class LibriSpeechAsrDataModule: features_mask_size=27, num_feature_masks=2, frames_mask_size=100, - max_frames_mask_fraction=0.3, + max_frames_mask_fraction=0.2, + p=0.9 ) ) else: diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 5c447bc4b..136faca57 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -109,7 +109,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/exp-100h-relu-specaug", + default="transducer_stateless/exp-100h-relu-specaugmod_p0.9_0.2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From c170c53006a7822e06832e960b630bcae964893a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Feb 2022 14:59:14 +0800 Subject: [PATCH 010/234] Change p=0.9 to p=0.8 in SpecAug --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 +- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 7df7a3525..044ad4fc6 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -223,7 +223,7 @@ class LibriSpeechAsrDataModule: num_feature_masks=2, frames_mask_size=100, max_frames_mask_fraction=0.2, - p=0.9 + p=0.8 ) ) else: diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 136faca57..62cd1e764 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -109,7 +109,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/exp-100h-relu-specaugmod_p0.9_0.2", + default="transducer_stateless/exp-100h-relu-specaugmod_p0.8_0.2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 4cd2c02fffac5cba4b0ca02d414fecbda90f7104 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Feb 2022 15:53:11 +0800 Subject: [PATCH 011/234] Fix num_time_masks code; revert 0.8 to 0.9 --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 11 +++++------ egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 044ad4fc6..df2e48421 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -223,7 +223,7 @@ class LibriSpeechAsrDataModule: num_feature_masks=2, frames_mask_size=100, max_frames_mask_fraction=0.2, - p=0.8 + p=0.9 ) ) else: @@ -518,11 +518,10 @@ class SpecAugment(torch.nn.Module): mask_value=mean, axis=2, ).squeeze(0) - for _ in range(self.num_frame_masks): - _max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) - num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size)) - max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks) - + _max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0) + num_frame_masks = min(self.num_frame_masks, math.ceil(_max_tot_mask_frames / self.frames_mask_size)) + max_mask_frames = min(self.frames_mask_size, _max_tot_mask_frames // num_frame_masks) + for _ in range(num_frame_masks): features = mask_along_axis( features.unsqueeze(0), mask_param=max_mask_frames, diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 62cd1e764..4bd85ca2e 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -109,7 +109,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/exp-100h-relu-specaugmod_p0.8_0.2", + default="transducer_stateless/exp-100h-relu-specaugmod_p0.9_0.2_fix", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From d187ad8b739b4df3dbd1940b768393f0eed91a8e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Feb 2022 16:24:17 +0800 Subject: [PATCH 012/234] Change max_frames from 0.2 to 0.15 --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 +- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index df2e48421..c1b16bcf0 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -222,7 +222,7 @@ class LibriSpeechAsrDataModule: features_mask_size=27, num_feature_masks=2, frames_mask_size=100, - max_frames_mask_fraction=0.2, + max_frames_mask_fraction=0.15, p=0.9 ) ) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 4bd85ca2e..dccf9b99b 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -109,7 +109,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/exp-100h-relu-specaugmod_p0.9_0.2_fix", + default="transducer_stateless/exp-100h-relu-specaugmod_p0.9_0.15_fix", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 2af1b3af981d9ede788f0a16d6032dc4d55a6ed9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 14 Feb 2022 19:39:19 +0800 Subject: [PATCH 013/234] Remove ReLU in attention --- .../ASR/transducer_stateless/conformer.py | 12 ++++++------ egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index c06335905..4627dd147 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -629,7 +629,7 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).relu().chunk(3, dim=-1) + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -640,7 +640,7 @@ class RelPositionMultiheadAttention(nn.Module): _w = in_proj_weight[_start:_end, :] if _b is not None: _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b).relu() + q = nn.functional.linear(query, _w, _b) # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias @@ -649,7 +649,7 @@ class RelPositionMultiheadAttention(nn.Module): _w = in_proj_weight[_start:, :] if _b is not None: _b = _b[_start:] - k, v = nn.functional.linear(key, _w, _b).relu().chunk(2, dim=-1) + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) else: # This is inline in_proj function with in_proj_weight and in_proj_bias @@ -659,7 +659,7 @@ class RelPositionMultiheadAttention(nn.Module): _w = in_proj_weight[_start:_end, :] if _b is not None: _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b).relu() + q = nn.functional.linear(query, _w, _b) # This is inline in_proj function with in_proj_weight and in_proj_bias @@ -669,7 +669,7 @@ class RelPositionMultiheadAttention(nn.Module): _w = in_proj_weight[_start:_end, :] if _b is not None: _b = _b[_start:_end] - k = nn.functional.linear(key, _w, _b).relu() + k = nn.functional.linear(key, _w, _b) # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias @@ -678,7 +678,7 @@ class RelPositionMultiheadAttention(nn.Module): _w = in_proj_weight[_start:, :] if _b is not None: _b = _b[_start:] - v = nn.functional.linear(value, _w, _b).relu() + v = nn.functional.linear(value, _w, _b) if attn_mask is not None: diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index dccf9b99b..7d1d7ff08 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -109,7 +109,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/exp-100h-relu-specaugmod_p0.9_0.15_fix", + default="transducer_stateless/exp-100h-specaugmod_p0.9_0.15_fix", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 581786a6d367e7d9313c43ae12030bc6044c9d0c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 27 Feb 2022 13:44:43 +0800 Subject: [PATCH 014/234] Adding diagnostics code... --- .../ASR/transducer_stateless/diagnostics.py | 284 ++++++++++++++++++ .../ASR/transducer_stateless/train.py | 40 ++- 2 files changed, 313 insertions(+), 11 deletions(-) create mode 100644 egs/librispeech/ASR/transducer_stateless/diagnostics.py diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py new file mode 100644 index 000000000..2dff91805 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless/diagnostics.py @@ -0,0 +1,284 @@ +import torch +from torch import Tensor +from torch import nn +import math +import random +from typing import Tuple, List + + +class TensorDiagnosticOptions(object): + """ + Options object for tensor diagnostics: + + Args: + memory_limit: the maximum number of bytes per tensor (limits how many copies + of the tensor we cache). + + """ + def __init__(self, memory_limit: int, + print_pos_ratio: bool = True): + self.memory_limit = memory_limit + self.print_pos_ratio = print_pos_ratio + + def dim_is_summarized(self, size: int): + return size > 10 and size != 31 + + def stats_types(self): + if self.print_pos_ratio: + return ["mean-abs", "pos-ratio"] + else: + return ["mean-abs"] + + + +def get_sum_abs_stats(x: Tensor, dim: int, + stats_type: str) -> Tuple[Tensor, int]: + """ + Returns the sum-of-absolute-value of this Tensor, for each + index into the specified axis/dim of the tensor. + Args: + x: Tensor, tensor to be analyzed + dim: dimension with 0 <= dim < x.ndim + stats_type: either "mean-abs" in which case the stats represent the + mean absolute value, or "pos-ratio" in which case the + stats represent the proportion of positive values (actually: + the tensor is count of positive values, count is the count of + all values). + Returns (sum_abs, count) + where sum_abs is a Tensor of shape (x.shape[dim],), and the count + is an integer saying how many items were counted in each element + of sum_abs. + """ + if stats_type == "mean-abs": + x = x.abs() + else: + assert stats_type == "pos-ratio" + x = (x > 0).to(dtype=torch.float) + orig_numel = x.numel() + sum_dims = [ d for d in range(x.ndim) if d != dim ] + x = torch.sum(x, dim=sum_dims) + count = orig_numel // x.numel() + x = x.flatten() + return x, count + +def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], + options: TensorDiagnosticOptions, + sizes_same: bool, + stats_type: str): + """ + This function gets diagnostics for a dimension of a module. + Args: + dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim + options: options object + sizes_same: true if all the tensor sizes are the same on this dimension + stats_type: either "mean-abs" or "pos-ratio", dictates the type of stats + we accumulate, mean-abs is mean absolute value, "pos-ratio" + is proportion of positive to nonnegative values. + Returns: + Diagnostic as a string, either percentiles or the actual values, + see the code. + """ + # stats_and_counts is a list of pair (Tensor, int) + stats_and_counts = [ get_sum_abs_stats(x, dim, stats_type) for x in tensors ] + stats = [ x[0] for x in stats_and_counts ] + counts = [ x[1] for x in stats_and_counts ] + if sizes_same: + stats = torch.stack(stats).sum(dim=0) + count = sum(counts) + stats = stats / count + else: + stats = [ x[0] / x[1] for x in stats_and_counts ] + stats = torch.cat(stats, dim=0) + # if `summarize` we print percentiles of the stats; else, + # we print out individual elements. + summarize = (not sizes_same) or options.dim_is_summarized(stats.numel()) + if summarize: + # print out percentiles. + stats = stats.sort()[0] + num_percentiles = 10 + size = stats.numel() + percentiles = [] + for i in range(num_percentiles + 1): + index = (i * (size - 1)) // num_percentiles + percentiles.append(stats[index].item()) + percentiles = [ '%.2g' % x for x in percentiles ] + percentiles = ' '.join(percentiles) + return f'percentiles: [{percentiles}]' + else: + stats = stats.tolist() + stats = [ '%.2g' % x for x in stats ] + stats = '[' + ' '.join(stats) + ']' + return stats + + + +def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor], + options: TensorDiagnosticOptions): + + for stats_type in options.stats_types(): + # stats_type will be "mean-abs" or "pos-ratio". + sizes = [ x.shape[dim] for x in tensors ] + sizes_same = all([ x == sizes[0] for x in sizes ]) + s = get_diagnostics_for_dim(dim, tensors, + options, sizes_same, + stats_type) + + min_size = min(sizes) + max_size = max(sizes) + size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}" + # stats_type will be "mean-abs" or "pos-ratio". + print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}") + + +class TensorDiagnostic(object): + """ + This class is not directly used by the user, it is responsible for collecting + diagnostics for a single parameter tensor of a torch.Module. + """ + def __init__(self, + opts: TensorDiagnosticOptions, + name: str): + self.name = name + self.opts = opts + self.saved_tensors = [] + + def accumulate(self, x): + if isinstance(x, Tuple): + x = x[0] + if not isinstance(x, Tensor): + return + if x.device == torch.device('cpu'): + x = x.detach().clone() + else: + x = x.detach().to('cpu', non_blocking=True) + self.saved_tensors.append(x) + l = len(self.saved_tensors) + if l & (l - 1) == 0: # power of 2.. + self._limit_memory() + + def _limit_memory(self): + if len(self.saved_tensors) > 1024: + self.saved_tensors = self.saved_tensors[-1024:] + return + + tot_mem = 0.0 + for i in reversed(range(len(self.saved_tensors))): + tot_mem += self.saved_tensors[i].numel() * self.saved_tensors[i].element_size() + if tot_mem > self.opts.memory_limit: + self.saved_tensors = self.saved_tensors[i:] + return + + def print_diagnostics(self): + if len(self.saved_tensors) == 0: + print("{name}: no stats".format(name=self.name)) + return + if self.saved_tensors[0].ndim == 0: + # ensure there is at least one dim. + self.saved_tensors = [ x.unsqueeze(0) for x in self.saved_tensors ] + + ndim = self.saved_tensors[0].ndim + for dim in range(ndim): + print_diagnostics_for_dim(self.name, dim, + self.saved_tensors, + self.opts) + + +class ModelDiagnostic(object): + def __init__(self, opts: TensorDiagnosticOptions): + self.diagnostics = dict() + self.opts = opts + + def __getitem__(self, name: str): + if name not in self.diagnostics: + self.diagnostics[name] = TensorDiagnostic(self.opts, name) + return self.diagnostics[name] + + def print_diagnostics(self): + for k in sorted(self.diagnostics.keys()): + self.diagnostics[k].print_diagnostics() + + + +def attach_diagnostics(model: nn.Module, + opts: TensorDiagnosticOptions) -> ModelDiagnostic: + ans = ModelDiagnostic(opts) + for name, module in model.named_modules(): + if name == '': + name = "" + forward_diagnostic = TensorDiagnostic(opts, name + ".output") + backward_diagnostic = TensorDiagnostic(opts, name + ".grad") + + + # setting model_diagnostic=ans and n=name below, instead of trying to capture the variables, + # ensures that we use the current values. (matters for name, since + # the variable gets overwritten). these closures don't really capture + # by value, only by "the final value the variable got in the function" :-( + def forward_hook(_module, _input, _output, + _model_diagnostic=ans, _name=name): + if isinstance(_output, Tensor): + _model_diagnostic[f"{_name}.output"].accumulate(_output) + elif isinstance(_output, tuple): + for i, o in enumerate(_output): + _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o) + + def backward_hook(_module, _input, _output, + _model_diagnostic=ans, _name=name): + if isinstance(_output, Tensor): + _model_diagnostic[f"{_name}.grad"].accumulate(_output) + elif isinstance(_output, tuple): + for i, o in enumerate(_output): + _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o) + + module.register_forward_hook(forward_hook) + module.register_backward_hook(backward_hook) + + for name, parameter in model.named_parameters(): + + def param_backward_hook(grad, + _parameter=parameter, + _model_diagnostic=ans, + _name=name): + _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter) + _model_diagnostic[f"{_name}.param_grad"].accumulate(grad) + + parameter.register_hook(param_backward_hook) + return ans + + + +def _test_tensor_diagnostic(): + opts = TensorDiagnosticOptions(2**20, True) + + diagnostic = TensorDiagnostic(opts, "foo") + + for _ in range(10): + diagnostic.accumulate(torch.randn(50, 100) * 10.0) + + diagnostic.print_diagnostics() + + model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80)) + + diagnostic = attach_diagnostics(model, opts) + for _ in range(10): + T = random.randint(200, 300) + x = torch.randn(T, 100) + y = model(x) + y.sum().backward() + + diagnostic.print_diagnostics() + + + +if __name__ == '__main__': + _test_tensor_diagnostic() + + +def _test_func(): + ans = [] + for i in range(10): + x = list() + x.append(i) + def func(): + return x + ans.append(func) + return ans diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 7d1d7ff08..0e1bbeaff 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging +import diagnostics # ./diagnostics.py from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -109,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/exp-100h-specaugmod_p0.9_0.15_fix", + default="transducer_stateless/specaugmod_baseline", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -138,6 +139,13 @@ def get_parser(): "2 means tri-gram", ) + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + return parser @@ -487,6 +495,9 @@ def train_one_epoch( loss.backward() clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() + if params.print_diagnostics and batch_idx == 5: + return + if batch_idx % params.log_interval == 0: logging.info( @@ -494,9 +505,6 @@ def train_one_epoch( f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}" ) - - if batch_idx % params.log_interval == 0: - if tb_writer is not None: loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train @@ -599,6 +607,11 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions(2**22) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: train_cuts += librispeech.train_clean_360_cuts() @@ -626,13 +639,14 @@ def run(rank, world_size, args): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) @@ -660,6 +674,10 @@ def run(rank, world_size, args): world_size=world_size, ) + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + save_checkpoint( params=params, model=model, From 63d8d935d43b719a74bdaa5db3892e71a2b9fe69 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 27 Feb 2022 13:56:15 +0800 Subject: [PATCH 015/234] Refactor/simplify ConformerEncoder --- .../ASR/transducer_stateless/conformer.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 4627dd147..07b80076d 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import copy import math import warnings from typing import Optional, Tuple @@ -264,13 +264,12 @@ class ConformerEncoderLayer(nn.Module): return src -class ConformerEncoder(nn.TransformerEncoder): +class ConformerEncoder(nn.Module): r"""ConformerEncoder is a stack of N encoder layers Args: encoder_layer: an instance of the ConformerEncoderLayer() class (required). num_layers: the number of sub-encoder-layers in the encoder (required). - norm: the layer normalization component (optional). Examples:: >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) @@ -281,11 +280,12 @@ class ConformerEncoder(nn.TransformerEncoder): """ def __init__( - self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None + self, encoder_layer: nn.Module, num_layers: int ) -> None: - super(ConformerEncoder, self).__init__( - encoder_layer=encoder_layer, num_layers=num_layers, norm=norm - ) + super(ConformerEncoder, self).__init__() + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for i in range(num_layers)]) + self.num_layers = num_layers + def forward( self, @@ -320,9 +320,6 @@ class ConformerEncoder(nn.TransformerEncoder): src_key_padding_mask=src_key_padding_mask, ) - if self.norm is not None: - output = self.norm(output) - return output From c1063def9552fd3af9a6d54b304a9cc6939a8b93 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 27 Feb 2022 17:34:58 +0800 Subject: [PATCH 016/234] First version of rand-combine iterated-training-like idea. --- .../ASR/transducer_stateless/conformer.py | 224 +++++++++++++++++- .../ASR/transducer_stateless/train.py | 2 +- 2 files changed, 219 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 07b80076d..327849485 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -18,7 +18,7 @@ import copy import math import warnings -from typing import Optional, Tuple +from typing import Optional, Tuple, Sequence import torch from torch import Tensor, nn @@ -56,6 +56,7 @@ class Conformer(Transformer): cnn_module_kernel: int = 31, normalize_before: bool = True, vgg_frontend: bool = False, + aux_layer_period: int = 3 ) -> None: super(Conformer, self).__init__( num_features=num_features, @@ -80,10 +81,11 @@ class Conformer(Transformer): cnn_module_kernel, normalize_before, ) - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, + aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) self.normalize_before = normalize_before if self.normalize_before: - self.after_norm = nn.LayerNorm(d_model) + self.after_norm = nn.LayerNorm(d_model) # TODO: remove. else: # Note: TorchScript detects that self.after_norm could be used inside forward() # and throws an error without this change. @@ -280,12 +282,21 @@ class ConformerEncoder(nn.Module): """ def __init__( - self, encoder_layer: nn.Module, num_layers: int + self, encoder_layer: nn.Module, + num_layers: int, + aux_layers: Sequence[int], ) -> None: super(ConformerEncoder, self).__init__() self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for i in range(num_layers)]) + self.aux_layers = set(aux_layers + [num_layers - 1]) + assert num_layers - 1 not in aux_layers self.num_layers = num_layers - + num_channels = encoder_layer.norm_final.weight.numel() + self.combiner = RandomCombine(num_inputs=len(self.aux_layers), + num_channels=num_channels, + final_weight=0.5, + pure_prob=0.333, + stddev=2.0) def forward( self, @@ -312,14 +323,19 @@ class ConformerEncoder(nn.Module): """ output = src - for mod in self.layers: + outputs = [] + + for i, mod in enumerate(self.layers): output = mod( output, pos_emb, src_mask=mask, src_key_padding_mask=src_key_padding_mask, ) + if i in self.aux_layers: + outputs.append(output) + output = self.combiner(outputs) return output @@ -918,7 +934,203 @@ def identity(x): return x +class RandomCombine(torch.nn.Module): + """ + This module combines a list of Tensors, all with the same shape, to + produce a single output of that same shape which, in training time, + is a random combination of all the inputs; but which in test time + will be just the last input. + + All but the last input will have a linear transform before we + randomly combine them; these linear transforms will be initialzed + to the identity transform. + + The idea is that the list of Tensors will be a list of outputs of multiple + conformer layers. This has a similar effect as iterated loss. (See: + DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER + NETWORKS). + """ + def __init__(self, num_inputs: int, + num_channels: int, + final_weight: float = 0.5, + pure_prob: float = 0.5, + stddev: float = 2.0) -> None: + """ + Args: + num_inputs: The number of tensor inputs, which equals the number of layers' + outputs that are fed into this module. E.g. in an 18-layer neural + net if we output layers 16, 12, 18, num_inputs would be 3. + num_channels: The number of channels on the input, e.g. 512. + final_weight: The amount of weight or probability we assign to the + final layer when randomly choosing layers or when choosing + continuous layer weights. + pure_prob: The probability, on each frame, with which we choose + only a single layer to output (rather than an interpolation) + stddev: A standard deviation that we add to log-probs for computing + randomized weights. + + The method of choosing which layers, + or combinations of layers, to use, is conceptually as follows. + With probability `pure_prob`: + With probability `final_weight`: choose final layer, + Else: choose random non-final layer. + Else: + Choose initial log-weights that correspond to assigning + weight `final_weight` to the final layer and equal + weights to other layers; then add Gaussian noise + with variance `stddev` to these log-weights, and normalize + to weights (note: the average weight assigned to the + final layer here will not be `final_weight` if stddev>0). + """ + super(RandomCombine, self).__init__() + assert pure_prob >= 0 and pure_prob <= 1 + assert final_weight > 0 and final_weight < 1 + assert num_inputs >= 1 + self.linear = nn.ModuleList([nn.Linear(num_channels, num_channels, bias=True) + for _ in range(num_inputs - 1)]) + + self.num_inputs = num_inputs + self.final_weight = final_weight + self.pure_prob = pure_prob + self.stddev= stddev + + self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item() + self._reset_parameters() + + def _reset_parameters(self): + for i in range(len(self.linear)): + nn.init.eye_(self.linear[i].weight) + nn.init.constant_(self.linear[i].bias, 0.0) + + def forward(self, inputs: Sequence[Tensor]) -> Tensor: + """ + Forward function. + Args: + inputs: a list of Tensor, e.g. from various layers of a transformer. + All must be the same shape, of (*, num_channels) + Returns: + a Tensor of shape (*, num_channels). In test mode + this is just the final input. + """ + num_inputs = self.num_inputs + assert len(inputs) == num_inputs + if not self.training: + return inputs[-1] + + # Shape of weights: (*, num_inputs) + num_channels = inputs[0].shape[-1] + num_frames = inputs[0].numel() // num_channels + + mod_inputs = [] + for i in range(num_inputs - 1): + mod_inputs.append(self.linear[i](inputs[i])) + mod_inputs.append(inputs[num_inputs - 1]) + + + ndim = inputs[0].ndim + # stacked_inputs: (num_frames, num_channels, num_inputs) + stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape((num_frames, + num_channels, + num_inputs)) + + # weights: (num_frames, num_inputs) + weights = self._get_random_weights(inputs[0].dtype, inputs[0].device, + num_frames) + + weights = weights.reshape(num_frames, num_inputs, 1) + # ans: (num_frames, num_channels, 1) + ans = torch.matmul(stacked_inputs, weights) + # ans: (*, num_channels) + ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels) + + if __name__ == "__main__": + # for testing only... + print("Weights = ", weights.reshape(num_frames, num_inputs)) + return ans + + + def _get_random_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int) -> Tensor: + """ + Return a tensor of random weights, of shape (num_frames, self.num_inputs), + Args: + dtype: the data-type desired for the answer, e.g. float, double + device: the device needed for the answer + num_frames: the number of sets of weights desired + Returns: a tensor of shape (num_frames, self.num_inputs), such that + ans.sum(dim=1) is all ones. + + """ + pure_prob = self.pure_prob + if pure_prob == 0.0: + return self._get_random_mixed_weights(dtype, device, num_frames) + elif pure_prob == 1.0: + return self._get_random_pure_weights(dtype, device, num_frames) + else: + p = self._get_random_pure_weights(dtype, device, num_frames) + m = self._get_random_mixed_weights(dtype, device, num_frames) + return torch.where(torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m) + + def _get_random_pure_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): + """ + Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs), + Args: + dtype: the data-type desired for the answer, e.g. float, double + device: the device needed for the answer + num_frames: the number of sets of weights desired + Returns: a one-hot tensor of shape (num_frames, self.num_inputs), with + exactly one weight equal to 1.0 on each frame. + """ + + final_prob = self.final_weight + + # final contains self.num_inputs - 1 in all elements + final = torch.full((num_frames,), self.num_inputs - 1, device=device) + # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) + + indexes = torch.where(torch.rand(num_frames, device=device) < final_prob, + final, nonfinal) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(dtype=dtype) + return ans + + + def _get_random_mixed_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): + """ + Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs), + Args: + dtype: the data-type desired for the answer, e.g. float, double + device: the device needed for the answer + num_frames: the number of sets of weights desired + Returns: a tensor of shape (num_frames, self.num_inputs), which elements in [0..1] that + sum to one over the second axis, i.e. ans.sum(dim=1) is all ones. + """ + logprobs = torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) * self.stddev + logprobs[:,-1] += self.final_log_weight + return logprobs.softmax(dim=1) + + +def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): + print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}") + num_inputs = 3 + num_channels = 50 + m = RandomCombine(num_inputs=num_inputs, num_channels=num_channels, + final_weight=final_weight, pure_prob=pure_prob, stddev=stddev) + + x = [ torch.ones(3, 4, num_channels) for _ in range(num_inputs) ] + + y = m(x) + assert y.shape == x[0].shape + assert torch.allclose(y, x[0]) # .. since actually all ones. + + if __name__ == '__main__': + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.0) + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.3) + _test_random_combine(0.5, 1, 0.3) + _test_random_combine(0.5, 0.5, 0.3) + feature_dim = 50 c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4) batch_size = 5 diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 0e1bbeaff..8877d4e75 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline", + default="transducer_stateless/specaugmod_baseline_randcombine1", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 2ff520c8004eb7cfd43585ae840ca0fcb5bbcfae Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 28 Feb 2022 12:15:56 +0800 Subject: [PATCH 017/234] Improvements to diagnostics (RE those with 1 dim --- .../ASR/transducer_stateless/diagnostics.py | 41 +++++++++++-------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py index 2dff91805..088ef14cb 100644 --- a/egs/librispeech/ASR/transducer_stateless/diagnostics.py +++ b/egs/librispeech/ASR/transducer_stateless/diagnostics.py @@ -31,32 +31,34 @@ class TensorDiagnosticOptions(object): -def get_sum_abs_stats(x: Tensor, dim: int, +def get_tensor_stats(x: Tensor, dim: int, stats_type: str) -> Tuple[Tensor, int]: """ - Returns the sum-of-absolute-value of this Tensor, for each - index into the specified axis/dim of the tensor. + Returns the specified transformation of the Tensor (either x or x.abs() + or (x > 0), summed over all but the index `dim`. + Args: x: Tensor, tensor to be analyzed dim: dimension with 0 <= dim < x.ndim - stats_type: either "mean-abs" in which case the stats represent the - mean absolute value, or "pos-ratio" in which case the - stats represent the proportion of positive values (actually: - the tensor is count of positive values, count is the count of - all values). - Returns (sum_abs, count) - where sum_abs is a Tensor of shape (x.shape[dim],), and the count + stats_type: + "mean-abs" or "abs-value" -> take abs() before summing + "pos-ratio" -> take (x > 0) before summing + "value -> just sum x itself + Returns (stats, count) + where stats is a Tensor of shape (x.shape[dim],), and the count is an integer saying how many items were counted in each element - of sum_abs. + of stats. """ - if stats_type == "mean-abs": + if stats_type == "mean-abs" or stats_type == "abs-value": x = x.abs() - else: - assert stats_type == "pos-ratio" + elif stats_type == "pos-ratio": x = (x > 0).to(dtype=torch.float) + else: + assert stats_type == "value" orig_numel = x.numel() sum_dims = [ d for d in range(x.ndim) if d != dim ] - x = torch.sum(x, dim=sum_dims) + if len(sum_dims) > 0: + x = torch.sum(x, dim=sum_dims) count = orig_numel // x.numel() x = x.flatten() return x, count @@ -79,7 +81,7 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], see the code. """ # stats_and_counts is a list of pair (Tensor, int) - stats_and_counts = [ get_sum_abs_stats(x, dim, stats_type) for x in tensors ] + stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ] stats = [ x[0] for x in stats_and_counts ] counts = [ x[1] for x in stats_and_counts ] if sizes_same: @@ -114,9 +116,12 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions): + ndim = tensors[0].ndim + # options.stats_types() should return [ "mean-abs", "pos-ratio" ] in the + # normal case. + stats_types = options.stats_types() if ndim > 1 else [ "value", "abs-value" ] - for stats_type in options.stats_types(): - # stats_type will be "mean-abs" or "pos-ratio". + for stats_type in stats_types: sizes = [ x.shape[dim] for x in tensors ] sizes_same = all([ x == sizes[0] for x in sizes ]) s = get_diagnostics_for_dim(dim, tensors, From 9d1b4ae04682d12aef2fecb902f318fcf9cab716 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 2 Mar 2022 16:33:27 +0800 Subject: [PATCH 018/234] Add pelu to this good-performing setup.. --- .../ASR/conformer_ctc/subsampling.py | 38 ++++++++++++++++++- .../ASR/transducer_stateless/conformer.py | 17 +++------ .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 42 insertions(+), 15 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 542fb0364..b23071926 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -45,13 +45,14 @@ class Conv2dSubsampling(nn.Module): nn.Conv2d( in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), - nn.ReLU(), + PeLU(cutoff=-1.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), - nn.ReLU(), + PeLU(cutoff=-5.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) + self.out_norm = nn.LayerNorm(odim, elementwise_affine=False) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -70,6 +71,7 @@ class Conv2dSubsampling(nn.Module): b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + x = self.out_norm(x) return x @@ -159,3 +161,35 @@ class VggSubsampling(nn.Module): b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) return x + + +class PeLUFunction(torch.autograd.Function): + """ + Computes PeLU function (PeLUFunction.apply(x, cutoff, alpha)). + The function is: + x.relu() + alpha * (cutoff - x).relu() + E.g. consider cutoff = -1, alpha = 0.01. This will tend to prevent die-off + of neurons. + """ + @staticmethod + def forward(ctx, x: Tensor, cutoff: float, alpha: float) -> Tensor: + mask1 = (x >= 0) # >=, so there is deriv if x == 0. + p = cutoff - x + mask2 = (p >= 0) + ctx.save_for_backward(mask1, mask2) + ctx.alpha = alpha + return x.relu() + alpha * p.relu() + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None]: + mask1, mask2 = ctx.saved_tensors + return mask1 * ans_grad - (ctx.alpha * mask2) * ans_grad, None, None + + + +class PeLU(torch.nn.Module): + def __init__(self, cutoff: float = -1.0, alpha: float = 0.01) -> None: + super(PeLU, self).__init__() + self.cutoff = cutoff + self.alpha = alpha + def forward(self, x: Tensor) -> Tensor: + return PeLUFunction.apply(x, self.cutoff, self.alpha) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 327849485..066232a02 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,6 +19,7 @@ import copy import math import warnings from typing import Optional, Tuple, Sequence +from subsampling import PeLU import torch from torch import Tensor, nn @@ -84,12 +85,7 @@ class Conformer(Transformer): self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) self.normalize_before = normalize_before - if self.normalize_before: - self.after_norm = nn.LayerNorm(d_model) # TODO: remove. - else: - # Note: TorchScript detects that self.after_norm could be used inside forward() - # and throws an error without this change. - self.after_norm = identity + def forward( self, x: torch.Tensor, x_lens: torch.Tensor @@ -118,9 +114,6 @@ class Conformer(Transformer): x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C) - if self.normalize_before: - x = self.after_norm(x) - logits = self.encoder_output_layer(x) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -163,14 +156,14 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), - Swish(), + PeLU(), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), - Swish(), + PeLU(), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -889,7 +882,7 @@ class ConvolutionModule(nn.Module): padding=0, bias=bias, ) - self.activation = Swish() + self.activation = PeLU() def forward(self, x: Tensor) -> Tensor: """Compute convolution module. diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 8877d4e75..88b366245 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1", + default="transducer_stateless/specaugmod_baseline_randcombine1_pelu", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 9ed7d55a846047373a24f0c084bf7f325e9cbe95 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 2 Mar 2022 16:34:55 +0800 Subject: [PATCH 019/234] Small bug fixes/imports --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index b23071926..c97f1ef48 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -17,6 +17,8 @@ import torch import torch.nn as nn +from torch import Tensor +from typing import Tuple class Conv2dSubsampling(nn.Module): From 3fb559d2f02402af91707fe0df633bcca497fc4d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 2 Mar 2022 18:27:08 +0800 Subject: [PATCH 020/234] Add baseline for the PeLU expt, keeping only the small normalization-related changes. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/conformer.py | 6 +++--- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index c97f1ef48..0e5e2d3de 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -47,11 +47,11 @@ class Conv2dSubsampling(nn.Module): nn.Conv2d( in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), - PeLU(cutoff=-1.0), + nn.ReLU(), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), - PeLU(cutoff=-5.0), + nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = nn.LayerNorm(odim, elementwise_affine=False) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 066232a02..2b97047cf 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -156,14 +156,14 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), - PeLU(), + Swish(), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), - PeLU(), + Swish(), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -882,7 +882,7 @@ class ConvolutionModule(nn.Module): padding=0, bias=bias, ) - self.activation = PeLU() + self.activation = Swish() def forward(self, x: Tensor) -> Tensor: """Compute convolution module. diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 88b366245..283aaecdd 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_pelu", + default="transducer_stateless/specaugmod_baseline_randcombine1_pelu_base", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 5c177fc52b551c188bbb828cad1d13450553aca6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 3 Mar 2022 23:52:03 +0800 Subject: [PATCH 021/234] pelu_base->expscale, add 2xExpScale in subsampling, and in feedforward units. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 11 +++++++++++ egs/librispeech/ASR/transducer_stateless/conformer.py | 4 +++- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 0e5e2d3de..73493a7ea 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -48,10 +48,12 @@ class Conv2dSubsampling(nn.Module): in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), + ExpScale(odim, 1, 1, speed=2.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), + ExpScale(odim, 1, 1, speed=2.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = nn.LayerNorm(odim, elementwise_affine=False) @@ -195,3 +197,12 @@ class PeLU(torch.nn.Module): self.alpha = alpha def forward(self, x: Tensor) -> Tensor: return PeLUFunction.apply(x, self.cutoff, self.alpha) + +class ExpScale(torch.nn.Module): + def __init__(self, *shape, speed: float = 1.0): + super(ExpScale, self).__init__() + self.scale = nn.Parameter(torch.zeros(*shape)) + self.speed = speed + + def forward(self, x: Tensor) -> Tensor: + return x * (self.scale * self.speed).exp() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 2b97047cf..3789e02fd 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import copy import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import PeLU +from subsampling import PeLU, ExpScale import torch from torch import Tensor, nn @@ -157,6 +157,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), Swish(), + ExpScale(dim_feedforward, speed=2.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -164,6 +165,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), Swish(), + ExpScale(dim_feedforward, speed=2.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 283aaecdd..183a924c6 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_pelu_base", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 23b3aa233c792de86fb23b04f4a0160ba74f4d51 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 00:42:37 +0800 Subject: [PATCH 022/234] Double learning rate of exp-scale units --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/conformer.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 73493a7ea..3b35c2ebe 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -48,12 +48,12 @@ class Conv2dSubsampling(nn.Module): in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), - ExpScale(odim, 1, 1, speed=2.0), + ExpScale(odim, 1, 1, speed=4.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), - ExpScale(odim, 1, 1, speed=2.0), + ExpScale(odim, 1, 1, speed=4.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = nn.LayerNorm(odim, elementwise_affine=False) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 3789e02fd..59f317e90 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -157,7 +157,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), Swish(), - ExpScale(dim_feedforward, speed=2.0), + ExpScale(dim_feedforward, speed=4.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -165,7 +165,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), Swish(), - ExpScale(dim_feedforward, speed=2.0), + ExpScale(dim_feedforward, speed=4.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 183a924c6..a1ded87c6 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From bc6c720e257c0b586ca2257d5be14b5358012bc1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 10:52:05 +0800 Subject: [PATCH 023/234] Combine ExpScale and swish for memory reduction --- .../ASR/conformer_ctc/subsampling.py | 67 +++++++++++++++++++ .../ASR/transducer_stateless/conformer.py | 5 +- 2 files changed, 69 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 3b35c2ebe..600156bf1 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -206,3 +206,70 @@ class ExpScale(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: return x * (self.scale * self.speed).exp() + + + +def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor: + return (x * torch.sigmoid(x)) * (scale * speed).exp() + + +def _exp_scale_swish_backward(x: Tensor, scale: Tensor, speed: float) -> Tensor: + return (x * torch.sigmoid(x)) * (scale * speed).exp() + + +class ExpScaleSwishFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: + ctx.save_for_backward(x, scale) + ctx.speed = speed + return _exp_scale_swish(x, scale, speed) + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + x, scale = ctx.saved_tensors + x.requires_grad = True + scale.requires_grad = True + with torch.enable_grad(): + y = _exp_scale_swish(x, scale, ctx.speed) + y.backward(gradient=y_grad) + return x.grad, scale.grad, None + + +class ExpScaleSwish(torch.nn.Module): + # combines ExpScale an Swish + # caution: need to specify name for speed, e.g. ExpScaleSwish(50, speed=4.0) + def __init__(self, *shape, speed: float = 1.0): + super(ExpScaleSwish, self).__init__() + self.scale = nn.Parameter(torch.zeros(*shape)) + self.speed = speed + + def forward(self, x: Tensor) -> Tensor: + return ExpScaleSwishFunction.apply(x, self.scale, self.speed) + # return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp() + # return x * (self.scale * self.speed).exp() + +def _test_exp_scale_swish(): + class Swish(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + x1 = torch.randn(50, 60).detach() + x2 = x1.detach() + + m1 = ExpScaleSwish(50, 1, speed=4.0) + m2 = torch.nn.Sequential(Swish(), ExpScale(50, 1, speed=4.0)) + x1.requires_grad = True + x2.requires_grad = True + + y1 = m1(x1) + y2 = m2(x2) + assert torch.allclose(y1, y2) + y1.sum().backward() + y2.sum().backward() + assert torch.allclose(x1.grad, x2.grad) + + + +if __name__ == '__main__': + _test_exp_scale_swish() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 59f317e90..3386ed9b2 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -156,8 +156,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), - Swish(), - ExpScale(dim_feedforward, speed=4.0), + ExpScaleSwish(dim_feedforward, speed=4.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -165,7 +164,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), Swish(), - ExpScale(dim_feedforward, speed=4.0), + ExpScaleSwish(dim_feedforward, speed=4.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) From cd216f50b63e92f8bdce493428b553570615ead1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 11:03:01 +0800 Subject: [PATCH 024/234] Add import --- egs/librispeech/ASR/transducer_stateless/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 3386ed9b2..83e0f8bca 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import copy import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import PeLU, ExpScale +from subsampling import PeLU, ExpScale, ExpScaleSwish import torch from torch import Tensor, nn From 3d9ddc201680747cab89838d9abe9797225f0128 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 12:29:44 +0800 Subject: [PATCH 025/234] Fix backprop bug --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 600156bf1..a66421adf 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -220,7 +220,7 @@ def _exp_scale_swish_backward(x: Tensor, scale: Tensor, speed: float) -> Tensor: class ExpScaleSwishFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: - ctx.save_for_backward(x, scale) + ctx.save_for_backward(x.detach(), scale.detach()) ctx.speed = speed return _exp_scale_swish(x, scale, speed) From 503f8d521ce10d24e5fc1b62760c630843cf80c2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 13:08:56 +0800 Subject: [PATCH 026/234] Fix bug in diagnostics --- egs/librispeech/ASR/transducer_stateless/diagnostics.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py index 2dff91805..1a2324775 100644 --- a/egs/librispeech/ASR/transducer_stateless/diagnostics.py +++ b/egs/librispeech/ASR/transducer_stateless/diagnostics.py @@ -56,7 +56,8 @@ def get_sum_abs_stats(x: Tensor, dim: int, x = (x > 0).to(dtype=torch.float) orig_numel = x.numel() sum_dims = [ d for d in range(x.ndim) if d != dim ] - x = torch.sum(x, dim=sum_dims) + if len(sum_dims) != 0: + x = torch.sum(x, dim=sum_dims) count = orig_numel // x.numel() x = x.flatten() return x, count From 3207bd98a942f96e4a052d9984ea8ee0040f2269 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 13:16:40 +0800 Subject: [PATCH 027/234] Increase scale on Scale from 4 to 20 --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/conformer.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index a66421adf..e38a94d09 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -48,12 +48,12 @@ class Conv2dSubsampling(nn.Module): in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), - ExpScale(odim, 1, 1, speed=4.0), + ExpScale(odim, 1, 1, speed=20.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), - ExpScale(odim, 1, 1, speed=4.0), + ExpScale(odim, 1, 1, speed=20.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = nn.LayerNorm(odim, elementwise_affine=False) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 83e0f8bca..6907feb26 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -156,7 +156,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), - ExpScaleSwish(dim_feedforward, speed=4.0), + ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -164,7 +164,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), Swish(), - ExpScaleSwish(dim_feedforward, speed=4.0), + ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index a1ded87c6..c57968428 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale2", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 7e889996413bc757c2cc7160c6e96644467ab57e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 14:31:29 +0800 Subject: [PATCH 028/234] Increase scale from 20 to 50. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/conformer.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index e38a94d09..97b9ae97b 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -48,12 +48,12 @@ class Conv2dSubsampling(nn.Module): in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), - ExpScale(odim, 1, 1, speed=20.0), + ExpScale(odim, 1, 1, speed=50.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), - ExpScale(odim, 1, 1, speed=20.0), + ExpScale(odim, 1, 1, speed=50.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = nn.LayerNorm(odim, elementwise_affine=False) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 6907feb26..ef6b4ac97 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -156,7 +156,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), - ExpScaleSwish(dim_feedforward, speed=20.0), + ExpScaleSwish(dim_feedforward, speed=50.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -164,7 +164,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), Swish(), - ExpScaleSwish(dim_feedforward, speed=20.0), + ExpScaleSwish(dim_feedforward, speed=50.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index c57968428..980633ed6 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale4", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 9cc5999829ef1441d99804204d1f61a796bc4948 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 15:50:51 +0800 Subject: [PATCH 029/234] Fix duplicate Swish; replace norm+swish with swish+exp-scale in convolution module --- egs/librispeech/ASR/transducer_stateless/conformer.py | 9 +++------ egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index ef6b4ac97..dc6b54399 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -163,7 +163,6 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), - Swish(), ExpScaleSwish(dim_feedforward, speed=50.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), @@ -874,7 +873,9 @@ class ConvolutionModule(nn.Module): groups=channels, bias=bias, ) - self.norm = nn.LayerNorm(channels) + # shape: (channels, 1), broadcasts with (batch, channel, time). + self.activation = ExpScaleSwish(channels, 1, speed=50.0) + self.pointwise_conv2 = nn.Conv1d( channels, channels, @@ -883,7 +884,6 @@ class ConvolutionModule(nn.Module): padding=0, bias=bias, ) - self.activation = Swish() def forward(self, x: Tensor) -> Tensor: """Compute convolution module. @@ -905,9 +905,6 @@ class ConvolutionModule(nn.Module): # 1D Depthwise Conv x = self.depthwise_conv(x) # x is (batch, channels, time) - x = x.permute(0, 2, 1) - x = self.norm(x) - x = x.permute(0, 2, 1) x = self.activation(x) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 980633ed6..973733d4b 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale4", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From eb3ed5420249d1c70d48700d72837e1c8a646454 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 15:56:45 +0800 Subject: [PATCH 030/234] Reduce scale from 50 to 20 --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/conformer.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 97b9ae97b..e38a94d09 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -48,12 +48,12 @@ class Conv2dSubsampling(nn.Module): in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), - ExpScale(odim, 1, 1, speed=50.0), + ExpScale(odim, 1, 1, speed=20.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), nn.ReLU(), - ExpScale(odim, 1, 1, speed=50.0), + ExpScale(odim, 1, 1, speed=20.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = nn.LayerNorm(odim, elementwise_affine=False) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index dc6b54399..368165008 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -156,14 +156,14 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), - ExpScaleSwish(dim_feedforward, speed=50.0), + ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), - ExpScaleSwish(dim_feedforward, speed=50.0), + ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -874,7 +874,7 @@ class ConvolutionModule(nn.Module): bias=bias, ) # shape: (channels, 1), broadcasts with (batch, channel, time). - self.activation = ExpScaleSwish(channels, 1, speed=50.0) + self.activation = ExpScaleSwish(channels, 1, speed=20.0) self.pointwise_conv2 = nn.Conv1d( channels, From 6252282fd02f0a105e718091dd321cc71e205a95 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 4 Mar 2022 20:19:11 +0800 Subject: [PATCH 031/234] Add deriv-balancing code --- .../ASR/conformer_ctc/subsampling.py | 87 +++++++++++++++++++ .../ASR/transducer_stateless/conformer.py | 6 +- .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 93 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index e38a94d09..aa842a31f 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -47,11 +47,15 @@ class Conv2dSubsampling(nn.Module): nn.Conv2d( in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), + DerivBalancer(channel_dim=1, threshold=0.02, + max_factor=0.02), nn.ReLU(), ExpScale(odim, 1, 1, speed=20.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), + DerivBalancer(channel_dim=1, threshold=0.02, + max_factor=0.02), nn.ReLU(), ExpScale(odim, 1, 1, speed=20.0), ) @@ -248,6 +252,68 @@ class ExpScaleSwish(torch.nn.Module): # return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp() # return x * (self.scale * self.speed).exp() + + + +class DerivBalancerFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, channel_dim: int, + threshold: 0.05, max_factor: 0.05, + epsilon: 1.0e-10) -> Tensor: + if x.requires_grad: + if channel_dim < 0: + channel_dim += x.ndim + sum_dims = [d for d in range(x.ndim) if d != channel_dim] + proportion_positive = torch.mean((x > 0).to(x.dtype), dim=sum_dims, keepdim=True) + factor = (threshold - proportion_positive).relu() * (max_factor / threshold) + + ctx.save_for_backward(factor) + ctx.epsilon = epsilon + ctx.sum_dims = sum_dims + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: + factor, = ctx.saved_tensors + neg_delta_grad = x_grad.abs() * factor + if ctx.epsilon != 0.0: + sum_abs_grad = torch.sum(x_grad.abs(), dim=ctx.sum_dims, keepdim=True) + deriv_is_zero = (sum_abs_grad == 0.0) + neg_delta_grad += ctx.epsilon * deriv_is_zero + + return x_grad - neg_delta_grad, None, None, None, None + + + +class DerivBalancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 0 at the threshold to those extremal values when none + of the inputs are positive. + + When all grads are zero for a channel, this + module sets all the input derivatives for that channel to -epsilon; the + idea is to bring completely dead neurons back to life this way. + """ + def __init__(self, channel_dim: int, + threshold: float = 0.05, + max_factor: float = 0.05, + epsilon: float = 1.0e-10): + super(DerivBalancer, self).__init__() + self.channel_dim = channel_dim + self.threshold = threshold + self.max_factor = max_factor + self.epsilon = epsilon + + def forward(self, x: Tensor) -> Tensor: + return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold, + self.max_factor, self.epsilon) + + + def _test_exp_scale_swish(): class Swish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: @@ -271,5 +337,26 @@ def _test_exp_scale_swish(): +def _test_deriv_balancer(): + channel_dim = 0 + probs = torch.arange(0, 1, 0.01) + N = 500 + x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) + x = x.detach() + x.requires_grad = True + m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, epsilon=1.0e-10) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + y_grad[-1,:] = 0 + + y = m(x) + y.backward(gradient=y_grad) + print("x = ", x) + print("y grad = ", y_grad) + print("x grad = ", x.grad) + + + if __name__ == '__main__': + _test_deriv_balancer() _test_exp_scale_swish() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 368165008..056958ff6 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import copy import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import PeLU, ExpScale, ExpScaleSwish +from subsampling import PeLU, ExpScale, ExpScaleSwish, DerivBalancer import torch from torch import Tensor, nn @@ -156,6 +156,8 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), + DerivBalancer(channel_dim=-1, threshold=0.02, + max_factor=0.02), ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), @@ -163,6 +165,8 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), + DerivBalancer(channel_dim=-1, threshold=0.02, + max_factor=0.02), ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 973733d4b..6d6b3f240 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 65b09dd5f22f72923289fd68c1641ecd33fa0c52 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 5 Mar 2022 00:07:14 +0800 Subject: [PATCH 032/234] Double the threshold in brelu; slightly increase max_factor. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 8 ++++---- egs/librispeech/ASR/transducer_stateless/conformer.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index aa842a31f..ba0f08271 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -47,15 +47,15 @@ class Conv2dSubsampling(nn.Module): nn.Conv2d( in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), - DerivBalancer(channel_dim=1, threshold=0.02, - max_factor=0.02), + DerivBalancer(channel_dim=1, threshold=0.05, + max_factor=0.025), nn.ReLU(), ExpScale(odim, 1, 1, speed=20.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), - DerivBalancer(channel_dim=1, threshold=0.02, - max_factor=0.02), + DerivBalancer(channel_dim=1, threshold=0.05, + max_factor=0.025), nn.ReLU(), ExpScale(odim, 1, 1, speed=20.0), ) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 056958ff6..42d159ff5 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -156,8 +156,8 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), - DerivBalancer(channel_dim=-1, threshold=0.02, - max_factor=0.02), + DerivBalancer(channel_dim=-1, threshold=0.05, + max_factor=0.025), ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), @@ -165,8 +165,8 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), - DerivBalancer(channel_dim=-1, threshold=0.02, - max_factor=0.02), + DerivBalancer(channel_dim=-1, threshold=0.05, + max_factor=0.025), ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), From 0cd14ae739ecfe9f01ebccf5f4b18cc7b9cbc8c0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 5 Mar 2022 12:17:09 +0800 Subject: [PATCH 033/234] Fix exp dir --- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 6d6b3f240..eed89e6b9 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 5f2c0a09b7eede63054ef20627cf298e2223734d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 5 Mar 2022 16:28:24 +0800 Subject: [PATCH 034/234] Convert swish nonlinearities to ReLU --- .../ASR/conformer_ctc/subsampling.py | 78 ++++++++++++++++++- .../ASR/transducer_stateless/conformer.py | 11 ++- .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 82 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index ba0f08271..a500e42a9 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -49,15 +49,13 @@ class Conv2dSubsampling(nn.Module): ), DerivBalancer(channel_dim=1, threshold=0.05, max_factor=0.025), - nn.ReLU(), - ExpScale(odim, 1, 1, speed=20.0), + ExpScaleRelu(odim, 1, 1, speed=20.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), DerivBalancer(channel_dim=1, threshold=0.05, max_factor=0.025), - nn.ReLU(), - ExpScale(odim, 1, 1, speed=20.0), + ExpScaleRelu(odim, 1, 1, speed=20.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = nn.LayerNorm(odim, elementwise_affine=False) @@ -253,6 +251,60 @@ class ExpScaleSwish(torch.nn.Module): # return x * (self.scale * self.speed).exp() +def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor: + return (x * (scale * speed).exp()).relu() + + +class ExpScaleReluFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: + ctx.save_for_backward(x.detach(), scale.detach()) + ctx.speed = speed + return _exp_scale_swish(x, scale, speed) + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + x, scale = ctx.saved_tensors + x.requires_grad = True + scale.requires_grad = True + with torch.enable_grad(): + y = _exp_scale_swish(x, scale, ctx.speed) + y.backward(gradient=y_grad) + return x.grad, scale.grad, None + + + +class ExpScaleReluFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: + ctx.save_for_backward(x.detach(), scale.detach()) + ctx.speed = speed + return _exp_scale_relu(x, scale, speed) + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + x, scale = ctx.saved_tensors + x.requires_grad = True + scale.requires_grad = True + with torch.enable_grad(): + y = _exp_scale_relu(x, scale, ctx.speed) + y.backward(gradient=y_grad) + return x.grad, scale.grad, None + +class ExpScaleRelu(torch.nn.Module): + # combines ExpScale and Relu. + # caution: need to specify name for speed, e.g. ExpScaleRelu(50, speed=4.0) + def __init__(self, *shape, speed: float = 1.0): + super(ExpScaleRelu, self).__init__() + self.scale = nn.Parameter(torch.zeros(*shape)) + self.speed = speed + + def forward(self, x: Tensor) -> Tensor: + return ExpScaleReluFunction.apply(x, self.scale, self.speed) + # return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp() + # return x * (self.scale * self.speed).exp() + + class DerivBalancerFunction(torch.autograd.Function): @@ -335,6 +387,23 @@ def _test_exp_scale_swish(): y2.sum().backward() assert torch.allclose(x1.grad, x2.grad) +def _test_exp_scale_relu(): + + x1 = torch.randn(50, 60).detach() + x2 = x1.detach() + + m1 = ExpScaleRelu(50, 1, speed=4.0) + m2 = torch.nn.Sequential(nn.ReLU(), ExpScale(50, 1, speed=4.0)) + x1.requires_grad = True + x2.requires_grad = True + + y1 = m1(x1) + y2 = m2(x2) + assert torch.allclose(y1, y2) + y1.sum().backward() + y2.sum().backward() + assert torch.allclose(x1.grad, x2.grad) + def _test_deriv_balancer(): @@ -360,3 +429,4 @@ def _test_deriv_balancer(): if __name__ == '__main__': _test_deriv_balancer() _test_exp_scale_swish() + _test_exp_scale_relu() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 42d159ff5..7af145a1e 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import copy import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import PeLU, ExpScale, ExpScaleSwish, DerivBalancer +from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer import torch from torch import Tensor, nn @@ -158,7 +158,7 @@ class ConformerEncoderLayer(nn.Module): nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.025), - ExpScaleSwish(dim_feedforward, speed=20.0), + ExpScaleRelu(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -167,7 +167,7 @@ class ConformerEncoderLayer(nn.Module): nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.025), - ExpScaleSwish(dim_feedforward, speed=20.0), + ExpScaleRelu(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -877,8 +877,10 @@ class ConvolutionModule(nn.Module): groups=channels, bias=bias, ) + self.balancer = DerivBalancer(channel_dim=1, threshold=0.05, + max_factor=0.025) # shape: (channels, 1), broadcasts with (batch, channel, time). - self.activation = ExpScaleSwish(channels, 1, speed=20.0) + self.activation = ExpScaleRelu(channels, 1, speed=20.0) self.pointwise_conv2 = nn.Conv1d( channels, @@ -910,6 +912,7 @@ class ConvolutionModule(nn.Module): x = self.depthwise_conv(x) # x is (batch, channels, time) + x = self.balancer(x) x = self.activation(x) x = self.pointwise_conv2(x) # (batch, channel, time) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index eed89e6b9..b1cb6d043 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2relu", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 8a8b81cd181e209b7609d7e8d54467bfbe758271 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 5 Mar 2022 22:21:42 +0800 Subject: [PATCH 035/234] Replace relu with swish-squared. --- .../ASR/conformer_ctc/subsampling.py | 18 ++++++++++-------- .../ASR/transducer_stateless/conformer.py | 6 +++--- .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index a500e42a9..daf8fd251 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -212,12 +212,11 @@ class ExpScale(torch.nn.Module): def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor: - return (x * torch.sigmoid(x)) * (scale * speed).exp() - - -def _exp_scale_swish_backward(x: Tensor, scale: Tensor, speed: float) -> Tensor: - return (x * torch.sigmoid(x)) * (scale * speed).exp() - + # double-swish! + x = (x * torch.sigmoid(x)) + x = (x * torch.sigmoid(x)) + x = x * (scale * speed).exp() + return x class ExpScaleSwishFunction(torch.autograd.Function): @staticmethod @@ -247,8 +246,11 @@ class ExpScaleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: return ExpScaleSwishFunction.apply(x, self.scale, self.speed) - # return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp() - # return x * (self.scale * self.speed).exp() + # x = (x * torch.sigmoid(x)) + # x = (x * torch.sigmoid(x)) + # x = x * (self.scale * self.speed).exp() + # return x + def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor: diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 7af145a1e..5adb7ca4e 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -158,7 +158,7 @@ class ConformerEncoderLayer(nn.Module): nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.025), - ExpScaleRelu(dim_feedforward, speed=20.0), + ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -167,7 +167,7 @@ class ConformerEncoderLayer(nn.Module): nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.025), - ExpScaleRelu(dim_feedforward, speed=20.0), + ExpScaleSwish(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -880,7 +880,7 @@ class ConvolutionModule(nn.Module): self.balancer = DerivBalancer(channel_dim=1, threshold=0.05, max_factor=0.025) # shape: (channels, 1), broadcasts with (batch, channel, time). - self.activation = ExpScaleRelu(channels, 1, speed=20.0) + self.activation = ExpScaleSwish(channels, 1, speed=20.0) self.pointwise_conv2 = nn.Conv1d( channels, diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index b1cb6d043..a3eca26c9 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2relu", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2swish2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From a37d98463aeaf0fd9370128cd0f03663bb3aaab1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 6 Mar 2022 11:55:02 +0800 Subject: [PATCH 036/234] Restore ConvolutionModule to state before changes; change all Swish,Swish(Swish) to SwishOffset. --- .../ASR/conformer_ctc/subsampling.py | 5 ++--- .../ASR/transducer_stateless/conformer.py | 22 ++++++++++++++----- .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index daf8fd251..1fe1265fa 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -212,9 +212,8 @@ class ExpScale(torch.nn.Module): def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor: - # double-swish! - x = (x * torch.sigmoid(x)) - x = (x * torch.sigmoid(x)) + # double-swish, implemented/approximated as offset-swish + x = (x * torch.sigmoid(x - 1.0)) x = x * (scale * speed).exp() return x diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 5adb7ca4e..62d9f382f 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -877,10 +877,10 @@ class ConvolutionModule(nn.Module): groups=channels, bias=bias, ) - self.balancer = DerivBalancer(channel_dim=1, threshold=0.05, - max_factor=0.025) - # shape: (channels, 1), broadcasts with (batch, channel, time). - self.activation = ExpScaleSwish(channels, 1, speed=20.0) + + self.norm = nn.LayerNorm(channels) + # shape: (channels, 1), broadcasts with (batch, channel, time). + self.activation = SwishOffset() self.pointwise_conv2 = nn.Conv1d( channels, @@ -911,8 +911,10 @@ class ConvolutionModule(nn.Module): # 1D Depthwise Conv x = self.depthwise_conv(x) # x is (batch, channels, time) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = x.permute(0, 2, 1) - x = self.balancer(x) x = self.activation(x) x = self.pointwise_conv2(x) # (batch, channel, time) @@ -927,6 +929,16 @@ class Swish(torch.nn.Module): """Return Swich activation function.""" return x * torch.sigmoid(x) +class SwishOffset(torch.nn.Module): + """Construct an SwishOffset object.""" + def __init__(self, offset: float = -1.0) -> None: + super(SwishOffset, self).__init__() + self.offset = offset + + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x + self.offset) + def identity(x): return x diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index a3eca26c9..16746147f 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale5_brelu2swish2", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From e2ace9d5457139dbc5a8092c9cc6afffab633857 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 7 Mar 2022 11:24:04 +0800 Subject: [PATCH 037/234] Replace norm on input layer with scale of 0.1. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 3 +-- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 1fe1265fa..2df2678dd 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -58,7 +58,6 @@ class Conv2dSubsampling(nn.Module): ExpScaleRelu(odim, 1, 1, speed=20.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) - self.out_norm = nn.LayerNorm(odim, elementwise_affine=False) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -77,7 +76,7 @@ class Conv2dSubsampling(nn.Module): b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) - x = self.out_norm(x) + x = x * 0.1 return x diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 16746147f..0dbd8479b 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From d074cf73c6ba428f3667ffede22a336febb72fb1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 9 Mar 2022 20:37:20 +0800 Subject: [PATCH 038/234] Extensions to diagnostics code --- .../ASR/transducer_stateless/diagnostics.py | 52 +++++++++++++++---- 1 file changed, 43 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py index 088ef14cb..dfbc2dced 100644 --- a/egs/librispeech/ASR/transducer_stateless/diagnostics.py +++ b/egs/librispeech/ASR/transducer_stateless/diagnostics.py @@ -25,7 +25,7 @@ class TensorDiagnosticOptions(object): def stats_types(self): if self.print_pos_ratio: - return ["mean-abs", "pos-ratio"] + return ["mean-abs", "pos-ratio", "value"] else: return ["mean-abs"] @@ -49,17 +49,23 @@ def get_tensor_stats(x: Tensor, dim: int, is an integer saying how many items were counted in each element of stats. """ - if stats_type == "mean-abs" or stats_type == "abs-value": + count = x.numel() // x.shape[dim] + + if stats_type == "eigs": + x = x.transpose(dim, -1) + x = x.reshape(-1, x.shape[-1]) + # shape of returned tensor: (s, s) where s is size of dimension `dim` of original x. + return torch.matmul(x.transpose(0, 1), x), count + elif stats_type == "mean-abs" or stats_type == "abs-value": x = x.abs() elif stats_type == "pos-ratio": x = (x > 0).to(dtype=torch.float) else: assert stats_type == "value" - orig_numel = x.numel() + sum_dims = [ d for d in range(x.ndim) if d != dim ] if len(sum_dims) > 0: x = torch.sum(x, dim=sum_dims) - count = orig_numel // x.numel() x = x.flatten() return x, count @@ -73,18 +79,35 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim options: options object sizes_same: true if all the tensor sizes are the same on this dimension - stats_type: either "mean-abs" or "pos-ratio", dictates the type of stats + stats_type: either "mean-abs" or "pos-ratio" or "eigs" or "value, + imdictates the type of stats we accumulate, mean-abs is mean absolute value, "pos-ratio" - is proportion of positive to nonnegative values. + is proportion of positive to nonnegative values, "eigs" + is eigenvalues after doing outer product on this dim, sum + over all other dimes. Returns: Diagnostic as a string, either percentiles or the actual values, - see the code. + see the code. Will return the empty string if the diagnostics did + not make sense to print out for this dimension, e.g. dimension + mismatch and stats_type == "eigs" """ # stats_and_counts is a list of pair (Tensor, int) + if tensors[0].shape[dim] > 512 and stats_type == 'eigs': + return '' # won't produce eigs stats if dim too large. stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ] stats = [ x[0] for x in stats_and_counts ] counts = [ x[1] for x in stats_and_counts ] - if sizes_same: + + if stats_type == 'eigs': + try: + stats = torch.stack(stats).sum(dim=0) + except: + return '' + count = sum(counts) + stats = stats / count + stats, _ = torch.symeig(stats) + stats = stats.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance + elif sizes_same: stats = torch.stack(stats).sum(dim=0) count = sum(counts) stats = stats / count @@ -121,12 +144,16 @@ def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor], # normal case. stats_types = options.stats_types() if ndim > 1 else [ "value", "abs-value" ] + stats_types = stats_types + ["eigs"] + for stats_type in stats_types: sizes = [ x.shape[dim] for x in tensors ] sizes_same = all([ x == sizes[0] for x in sizes ]) s = get_diagnostics_for_dim(dim, tensors, options, sizes_same, stats_type) + if s == '': + continue min_size = min(sizes) max_size = max(sizes) @@ -181,10 +208,17 @@ class TensorDiagnostic(object): # ensure there is at least one dim. self.saved_tensors = [ x.unsqueeze(0) for x in self.saved_tensors ] + try: + device = torch.device('cuda') + torch.ones(1, 1, device) + except: + device = torch.device('cpu') + ndim = self.saved_tensors[0].ndim + tensors = [x.to(device) for x in self.saved_tensors] for dim in range(ndim): print_diagnostics_for_dim(self.name, dim, - self.saved_tensors, + tensors, self.opts) From 1e5455ba2904efab594e68e16d548de32f104a14 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Mar 2022 10:28:48 +0800 Subject: [PATCH 039/234] Update diagnostics --- .../ASR/transducer_stateless/diagnostics.py | 58 ++++++++++--------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py index dfbc2dced..8ea35582a 100644 --- a/egs/librispeech/ASR/transducer_stateless/diagnostics.py +++ b/egs/librispeech/ASR/transducer_stateless/diagnostics.py @@ -11,24 +11,21 @@ class TensorDiagnosticOptions(object): Options object for tensor diagnostics: Args: - memory_limit: the maximum number of bytes per tensor (limits how many copies + memory_limit: the maximum number of bytes we store per tensor (limits how many copies of the tensor we cache). - + max_eig_dim: the maximum dimension for which we print out eigenvalues + (limited for speed reasons). """ - def __init__(self, memory_limit: int, - print_pos_ratio: bool = True): + def __init__(self, + memory_limit: int = (2 ** 20), + max_eig_dim: int = 512): + self.memory_limit = memory_limit - self.print_pos_ratio = print_pos_ratio + self.max_eig_dim = max_eig_dim def dim_is_summarized(self, size: int): return size > 10 and size != 31 - def stats_types(self): - if self.print_pos_ratio: - return ["mean-abs", "pos-ratio", "value"] - else: - return ["mean-abs"] - def get_tensor_stats(x: Tensor, dim: int, @@ -41,8 +38,9 @@ def get_tensor_stats(x: Tensor, dim: int, x: Tensor, tensor to be analyzed dim: dimension with 0 <= dim < x.ndim stats_type: - "mean-abs" or "abs-value" -> take abs() before summing - "pos-ratio" -> take (x > 0) before summing + "abs" -> take abs() before summing + "positive" -> take (x > 0) before summing + "rms" -> square before summing, we'll take sqrt later "value -> just sum x itself Returns (stats, count) where stats is a Tensor of shape (x.shape[dim],), and the count @@ -56,9 +54,11 @@ def get_tensor_stats(x: Tensor, dim: int, x = x.reshape(-1, x.shape[-1]) # shape of returned tensor: (s, s) where s is size of dimension `dim` of original x. return torch.matmul(x.transpose(0, 1), x), count - elif stats_type == "mean-abs" or stats_type == "abs-value": + elif stats_type == "abs": x = x.abs() - elif stats_type == "pos-ratio": + elif stats_type == "rms": + x = x ** 2 + elif stats_type == "positive": x = (x > 0).to(dtype=torch.float) else: assert stats_type == "value" @@ -79,9 +79,9 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim options: options object sizes_same: true if all the tensor sizes are the same on this dimension - stats_type: either "mean-abs" or "pos-ratio" or "eigs" or "value, + stats_type: either "abs" or "positive" or "eigs" or "value, imdictates the type of stats - we accumulate, mean-abs is mean absolute value, "pos-ratio" + we accumulate, abs is mean absolute value, "positive" is proportion of positive to nonnegative values, "eigs" is eigenvalues after doing outer product on this dim, sum over all other dimes. @@ -92,13 +92,11 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], mismatch and stats_type == "eigs" """ # stats_and_counts is a list of pair (Tensor, int) - if tensors[0].shape[dim] > 512 and stats_type == 'eigs': - return '' # won't produce eigs stats if dim too large. stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ] stats = [ x[0] for x in stats_and_counts ] counts = [ x[1] for x in stats_and_counts ] - if stats_type == 'eigs': + if stats_type == "eigs": try: stats = torch.stack(stats).sum(dim=0) except: @@ -114,6 +112,9 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], else: stats = [ x[0] / x[1] for x in stats_and_counts ] stats = torch.cat(stats, dim=0) + if stats_type == 'rms': + stats = stats.sqrt() + # if `summarize` we print percentiles of the stats; else, # we print out individual elements. summarize = (not sizes_same) or options.dim_is_summarized(stats.numel()) @@ -140,11 +141,12 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor], options: TensorDiagnosticOptions): ndim = tensors[0].ndim - # options.stats_types() should return [ "mean-abs", "pos-ratio" ] in the - # normal case. - stats_types = options.stats_types() if ndim > 1 else [ "value", "abs-value" ] - - stats_types = stats_types + ["eigs"] + if ndim > 1: + stats_types = ["abs", "positive", "value", "rms"] + if tensors[0].shape[dim] <= options.max_eig_dim: + stats_types.append("eigs") + else: + stats_types = [ "value", "abs" ] for stats_type in stats_types: sizes = [ x.shape[dim] for x in tensors ] @@ -158,7 +160,7 @@ def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor], min_size = min(sizes) max_size = max(sizes) size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}" - # stats_type will be "mean-abs" or "pos-ratio". + # stats_type will be "abs" or "positive". print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}") @@ -223,7 +225,7 @@ class TensorDiagnostic(object): class ModelDiagnostic(object): - def __init__(self, opts: TensorDiagnosticOptions): + def __init__(self, opts: TensorDiagnosticOptions = TensorDiagnosticOptions()): self.diagnostics = dict() self.opts = opts @@ -286,7 +288,7 @@ def attach_diagnostics(model: nn.Module, def _test_tensor_diagnostic(): - opts = TensorDiagnosticOptions(2**20, True) + opts = TensorDiagnosticOptions(2**20, 512) diagnostic = TensorDiagnostic(opts, "foo") From 059b57ad37c98ba228a708821eafea0f3c152146 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Mar 2022 14:32:05 +0800 Subject: [PATCH 040/234] Add BasicNorm module --- .../ASR/conformer_ctc/subsampling.py | 101 +++++++++++++++++- 1 file changed, 98 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 2df2678dd..622495f21 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -336,6 +336,83 @@ class DerivBalancerFunction(torch.autograd.Function): return x_grad - neg_delta_grad, None, None, None, None +class BasicNorm(torch.nn.Module): + """ + This is intended to be a simpler, and hopefully cheaper, replacement for + LayerNorm. The observation this is based on, is that Transformer-type + networks, especially with pre-norm, sometimes seem to set one of the + feature dimensions to a large constant value (e.g. 50), which "defeats" + the LayerNorm because the output magnitude is then not strongly dependent + on the other (useful) features. Presumably the weight and bias of the + LayerNorm are required to allow it to do this. + + So the idea is to introduce this large constant value as an explicit + parameter, that takes the role of the "eps" in LayerNorm, so the network + doesn't have to do this trick. + + We also introduce a learned scaling factor on the output; and we + remove the subtracting-the-mean aspect of LayerNorm (which anyway, is not + that useful unless the LayerNorm immediately follows a nonlinearity). + + + Args: + channel_dim: the axis/dimension corresponding to the channel, + interprted as an offset from the input's ndim if negative. + shis is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. + initial_eps_scale: a constant that determines the initial + "epsilon" that we add as ballast in: + scale = output_scale * ((input_vec**2).sum() + epsilon)**-0.5 + Note: our epsilon is actually large, not small, but we keep the name + to indicate the connection with normal LayerNorm. We set + epsilon initially to num_channels * initial_eps_scale. + speed: a scaling factor that can be interpreted as scaling the learning + rate for this module. CAUTION: the default value of 10.0 intended to be + used with Adam or amsgrad-type optimizers, e.g. Adam or Noam. + If you are using SGD you would probably have to set `speed` to + a value less than one, or the training would be unstable. + """ + def __init__(self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + initial_eps_scale: float = 0.25, + speed: float = 10.0): + super(BasicNorm, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.speed = speed + eps = num_channels * initial_eps_scale + # log_eps = log(eps) / speed + log_eps = torch.tensor(eps).log() / speed + self.log_eps = nn.Parameter(log_eps.detach()) + # initial output-scale, to get LayerNorm-like behavior, is + # sqrt(num_channels). + initial_scale = torch.tensor(num_channels ** 0.5).log() / speed + self.log_scale = nn.Parameter(initial_scale.detach()) + + def _inner(self, x: Tensor) -> Tensor: + # inner product on last dim of x, keeping the dimension, + # i.e. torch.sum(x**2, dim=-1, keepdim=True), but more + # efficient. + if hasattr(torch, 'inner'): + return torch.inner(x).unsqueeze(-1) + else: + # TODO: we can do this with matrix multiplication, maybe.a + return torch.sum(x**2, dim=-1, keepdim=True) + + + def forward(self, x: Tensor) -> Tensor: + assert x.shape[self.channel_dim] == self.num_channels + x = x.transpose(-1, self.channel_dim) + eps = (self.log_eps * self.speed).exp() + out_scale = (self.log_scale * self.speed).exp() + + scales = out_scale * (self._inner(x) + eps) ** -0.5 + x = x * scales + x = x.transpose(-1, self.channel_dim) + return x + + class DerivBalancer(torch.nn.Module): """ @@ -367,16 +444,16 @@ class DerivBalancer(torch.nn.Module): def _test_exp_scale_swish(): - class Swish(torch.nn.Module): + class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return Swich activation function.""" - return x * torch.sigmoid(x) + return x * torch.sigmoid(x - 1.0) x1 = torch.randn(50, 60).detach() x2 = x1.detach() m1 = ExpScaleSwish(50, 1, speed=4.0) - m2 = torch.nn.Sequential(Swish(), ExpScale(50, 1, speed=4.0)) + m2 = torch.nn.Sequential(DoubleSwish(), ExpScale(50, 1, speed=4.0)) x1.requires_grad = True x2.requires_grad = True @@ -425,8 +502,26 @@ def _test_deriv_balancer(): print("x grad = ", x.grad) +def _test_basic_norm(): + num_channels = 128 + m = BasicNorm(num_channels=num_channels, channel_dim=1) + + x = torch.randn(500, num_channels) + + y = m(x) + + assert y.shape == x.shape + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() + print("x rms = ", x_rms) + print("y rms = ", y_rms) + assert y_rms < x_rms + assert y_rms > 0.5 * x_rms + + if __name__ == '__main__': _test_deriv_balancer() _test_exp_scale_swish() _test_exp_scale_relu() + _test_basic_norm() From b55472bb427a2407797e028bf929ff0b7f55f18b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Mar 2022 14:43:54 +0800 Subject: [PATCH 041/234] Replace most normalizations with scales (still have norm in conv) --- .../ASR/conformer_ctc/subsampling.py | 9 ++- .../ASR/transducer_stateless/conformer.py | 57 ++++++------------- 2 files changed, 24 insertions(+), 42 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 622495f21..29621bf52 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -58,6 +58,7 @@ class Conv2dSubsampling(nn.Module): ExpScaleRelu(odim, 1, 1, speed=20.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) + self.out_norm = BasicNorm(odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -76,7 +77,7 @@ class Conv2dSubsampling(nn.Module): b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) - x = x * 0.1 + x = self.out_norm(x) return x @@ -200,9 +201,11 @@ class PeLU(torch.nn.Module): return PeLUFunction.apply(x, self.cutoff, self.alpha) class ExpScale(torch.nn.Module): - def __init__(self, *shape, speed: float = 1.0): + def __init__(self, *shape, speed: float = 1.0, initial_scale: float = 1.0): super(ExpScale, self).__init__() - self.scale = nn.Parameter(torch.zeros(*shape)) + scale = torch.tensor(initial_scale) + scale = scale.log() / speed + self.scale = nn.Parameter(scale.detach()) self.speed = speed def forward(self, x: Tensor) -> Tensor: diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 62d9f382f..acaf064b3 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import copy import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer +from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer, BasicNorm import torch from torch import Tensor, nn @@ -150,6 +150,8 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() + self.d_model = d_model + self.self_attn = RelPositionMultiheadAttention( d_model, nhead, dropout=0.0 ) @@ -174,22 +176,15 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module - self.norm_ff = nn.LayerNorm(d_model) # for the FNN module - self.norm_mha = nn.LayerNorm(d_model) # for the MHA module + self.scale_mha = ExpScale(1, speed=10.0, initial_scale=0.2) + self.scale_conv = ExpScale(1, speed=10.0, initial_scale=0.5) + self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5) + self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5) - self.ff_scale = 0.5 - - self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block + self.norm_final = BasicNorm(d_model) self.dropout = nn.Dropout(dropout) - self.normalize_before = normalize_before def forward( self, @@ -217,18 +212,15 @@ class ConformerEncoderLayer(nn.Module): # macaron style feed forward module residual = src - if self.normalize_before: - src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) - if not self.normalize_before: - src = self.norm_ff_macaron(src) + + + src = src + self.dropout(self.feed_forward_macaron( + self.scale_ff_macaron(src))) + # multi-headed self-attention module residual = src - if self.normalize_before: - src = self.norm_mha(src) + src = self.scale_mha(src) src_att = self.self_attn( src, src, @@ -238,27 +230,14 @@ class ConformerEncoderLayer(nn.Module): key_padding_mask=src_key_padding_mask, )[0] src = residual + self.dropout(src_att) - if not self.normalize_before: - src = self.norm_mha(src) # convolution module - residual = src - if self.normalize_before: - src = self.norm_conv(src) - src = residual + self.dropout(self.conv_module(src)) - if not self.normalize_before: - src = self.norm_conv(src) + src = residual + self.dropout(self.conv_module(self.scale_conv(src))) # feed forward module - residual = src - if self.normalize_before: - src = self.norm_ff(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) - if not self.normalize_before: - src = self.norm_ff(src) + src = src + self.dropout(self.feed_forward(self.scale_ff(src))) - if self.normalize_before: - src = self.norm_final(src) + src = self.norm_final(src) return src @@ -288,7 +267,7 @@ class ConformerEncoder(nn.Module): self.aux_layers = set(aux_layers + [num_layers - 1]) assert num_layers - 1 not in aux_layers self.num_layers = num_layers - num_channels = encoder_layer.norm_final.weight.numel() + num_channels = encoder_layer.d_model self.combiner = RandomCombine(num_inputs=len(self.aux_layers), num_channels=num_channels, final_weight=0.5, From 87b843f02301738395a6d7c0651a295e11a92a08 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Mar 2022 14:44:55 +0800 Subject: [PATCH 042/234] Change exp dir --- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 0dbd8479b..4fd4bf764 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 425e274c82029217623b263944aaa2b407ef5847 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Mar 2022 16:01:53 +0800 Subject: [PATCH 043/234] Replace norm in ConvolutionModule with a scaling factor. --- egs/librispeech/ASR/transducer_stateless/conformer.py | 5 +++-- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index acaf064b3..4cf66e2fe 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -857,7 +857,8 @@ class ConvolutionModule(nn.Module): bias=bias, ) - self.norm = nn.LayerNorm(channels) + self.scale = ExpScale(1, speed=10.0, initial_scale=1.0) + # shape: (channels, 1), broadcasts with (batch, channel, time). self.activation = SwishOffset() @@ -891,7 +892,7 @@ class ConvolutionModule(nn.Module): x = self.depthwise_conv(x) # x is (batch, channels, time) x = x.permute(0, 2, 1) - x = self.norm(x) + x = self.scale(x) x = x.permute(0, 2, 1) x = self.activation(x) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 4fd4bf764..c355c7ad3 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 2fa9c636a44b105b18ef403afe6f4c1ff7d73529 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Mar 2022 23:24:55 +0800 Subject: [PATCH 044/234] use nonzero threshold in DerivBalancer --- .../ASR/conformer_ctc/subsampling.py | 47 +++++++++++++------ .../ASR/transducer_stateless/conformer.py | 6 +-- .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 37 insertions(+), 18 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 29621bf52..390d31115 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -219,7 +219,7 @@ def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor: x = x * (scale * speed).exp() return x -class ExpScaleSwishFunction(torch.autograd.Function): +class SwishExpScaleFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: ctx.save_for_backward(x.detach(), scale.detach()) @@ -237,16 +237,16 @@ class ExpScaleSwishFunction(torch.autograd.Function): return x.grad, scale.grad, None -class ExpScaleSwish(torch.nn.Module): - # combines ExpScale an Swish - # caution: need to specify name for speed, e.g. ExpScaleSwish(50, speed=4.0) +class SwishExpScale(torch.nn.Module): + # combines ExpScale and a Swish (actually the ExpScale is after the Swish). + # caution: need to specify name for speed, e.g. SwishExpScale(50, speed=4.0) def __init__(self, *shape, speed: float = 1.0): - super(ExpScaleSwish, self).__init__() + super(SwishExpScale, self).__init__() self.scale = nn.Parameter(torch.zeros(*shape)) self.speed = speed def forward(self, x: Tensor) -> Tensor: - return ExpScaleSwishFunction.apply(x, self.scale, self.speed) + return SwishExpScaleFunction.apply(x, self.scale, self.speed) # x = (x * torch.sigmoid(x)) # x = (x * torch.sigmoid(x)) # x = x * (self.scale * self.speed).exp() @@ -313,13 +313,15 @@ class ExpScaleRelu(torch.nn.Module): class DerivBalancerFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, channel_dim: int, - threshold: 0.05, max_factor: 0.05, - epsilon: 1.0e-10) -> Tensor: + threshold: float = 0.05, + max_factor: float = 0.05, + zero: float = 0.02, + epsilon: float = 1.0e-10) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] - proportion_positive = torch.mean((x > 0).to(x.dtype), dim=sum_dims, keepdim=True) + proportion_positive = torch.mean((x > zero).to(x.dtype), dim=sum_dims, keepdim=True) factor = (threshold - proportion_positive).relu() * (max_factor / threshold) ctx.save_for_backward(factor) @@ -328,7 +330,7 @@ class DerivBalancerFunction(torch.autograd.Function): return x @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: factor, = ctx.saved_tensors neg_delta_grad = x_grad.abs() * factor if ctx.epsilon != 0.0: @@ -336,7 +338,7 @@ class DerivBalancerFunction(torch.autograd.Function): deriv_is_zero = (sum_abs_grad == 0.0) neg_delta_grad += ctx.epsilon * deriv_is_zero - return x_grad - neg_delta_grad, None, None, None, None + return x_grad - neg_delta_grad, None, None, None, None, None class BasicNorm(torch.nn.Module): @@ -429,20 +431,37 @@ class DerivBalancer(torch.nn.Module): When all grads are zero for a channel, this module sets all the input derivatives for that channel to -epsilon; the idea is to bring completely dead neurons back to life this way. + + Args: + channel_dim: the dimension/axi corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + threshold: the threshold, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_factor: the maximum factor by which we modify the derivatives, + e.g. with max_factor=0.02, the the derivatives would be multiplied by + values in the range [0.98..1.01]. + zero: we use this value in the comparison (x > 0), i.e. we actually use + (x > zero). The reason for using a threshold slightly greater + than zero is that it will tend to prevent situations where the + inputs shrink close to zero and the nonlinearity (e.g. swish) + behaves like a linear function and we learn nothing. """ def __init__(self, channel_dim: int, threshold: float = 0.05, - max_factor: float = 0.05, + max_factor: float = 0.02, + zero: float = 0.02, epsilon: float = 1.0e-10): super(DerivBalancer, self).__init__() self.channel_dim = channel_dim self.threshold = threshold self.max_factor = max_factor + self.zero = zero self.epsilon = epsilon def forward(self, x: Tensor) -> Tensor: return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold, - self.max_factor, self.epsilon) + self.max_factor, self.zero, + self.epsilon) @@ -455,7 +474,7 @@ def _test_exp_scale_swish(): x1 = torch.randn(50, 60).detach() x2 = x1.detach() - m1 = ExpScaleSwish(50, 1, speed=4.0) + m1 = SwishExpScale(50, 1, speed=4.0) m2 = torch.nn.Sequential(DoubleSwish(), ExpScale(50, 1, speed=4.0)) x1.requires_grad = True x2.requires_grad = True diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 4cf66e2fe..7a7a09c27 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import copy import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import PeLU, ExpScale, ExpScaleSwish, ExpScaleRelu, DerivBalancer, BasicNorm +from subsampling import PeLU, ExpScale, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm import torch from torch import Tensor, nn @@ -160,7 +160,7 @@ class ConformerEncoderLayer(nn.Module): nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.025), - ExpScaleSwish(dim_feedforward, speed=20.0), + SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -169,7 +169,7 @@ class ConformerEncoderLayer(nn.Module): nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.025), - ExpScaleSwish(dim_feedforward, speed=20.0), + SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index c355c7ad3..36a1ae869 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2z0.02", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 76560f255c3ab5f88f8bf318c14fc5d81eb9c429 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Mar 2022 23:48:46 +0800 Subject: [PATCH 045/234] Add min-abs-value 0.2 --- .../ASR/conformer_ctc/subsampling.py | 72 ++++++++++++------- .../ASR/transducer_stateless/train.py | 2 +- 2 files changed, 47 insertions(+), 27 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 390d31115..d1ff7f233 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -312,33 +312,36 @@ class ExpScaleRelu(torch.nn.Module): class DerivBalancerFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, channel_dim: int, + def forward(ctx, x: Tensor, + channel_dim: int, threshold: float = 0.05, max_factor: float = 0.05, - zero: float = 0.02, - epsilon: float = 1.0e-10) -> Tensor: + min_abs: float = 0.2) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] - proportion_positive = torch.mean((x > zero).to(x.dtype), dim=sum_dims, keepdim=True) + xgt0 = x > 0 + proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) factor = (threshold - proportion_positive).relu() * (max_factor / threshold) - ctx.save_for_backward(factor) - ctx.epsilon = epsilon + below_threshold = (torch.mean(x.abs(), dim=sum_dims, keepdim=True) < min_abs) + + ctx.save_for_backward(factor, xgt0, below_threshold) + ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: - factor, = ctx.saved_tensors - neg_delta_grad = x_grad.abs() * factor - if ctx.epsilon != 0.0: - sum_abs_grad = torch.sum(x_grad.abs(), dim=ctx.sum_dims, keepdim=True) - deriv_is_zero = (sum_abs_grad == 0.0) - neg_delta_grad += ctx.epsilon * deriv_is_zero + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: + factor, xgt0, below_threshold = ctx.saved_tensors + dtype = x_grad.dtype + too_small_factor = below_threshold.to(dtype) * (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0) - return x_grad - neg_delta_grad, None, None, None, None, None + neg_delta_grad = x_grad.abs() * (factor + too_small_factor) + + + return x_grad - neg_delta_grad, None, None, None, None class BasicNorm(torch.nn.Module): @@ -449,19 +452,17 @@ class DerivBalancer(torch.nn.Module): def __init__(self, channel_dim: int, threshold: float = 0.05, max_factor: float = 0.02, - zero: float = 0.02, - epsilon: float = 1.0e-10): + min_abs: float = 0.2): super(DerivBalancer, self).__init__() self.channel_dim = channel_dim self.threshold = threshold self.max_factor = max_factor - self.zero = zero - self.epsilon = epsilon + self.min_abs = min_abs def forward(self, x: Tensor) -> Tensor: return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold, - self.max_factor, self.zero, - self.epsilon) + self.max_factor, self.min_abs) + @@ -505,23 +506,41 @@ def _test_exp_scale_relu(): -def _test_deriv_balancer(): +def _test_deriv_balancer_sign(): channel_dim = 0 probs = torch.arange(0, 1, 0.01) N = 500 x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = x.detach() x.requires_grad = True - m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, epsilon=1.0e-10) + m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2) y_grad = torch.sign(torch.randn(probs.numel(), N)) y_grad[-1,:] = 0 y = m(x) y.backward(gradient=y_grad) - print("x = ", x) - print("y grad = ", y_grad) - print("x grad = ", x.grad) + print("_test_deriv_balancer_sign: x = ", x) + print("_test_deriv_balancer_sign: y grad = ", y_grad) + print("_test_deriv_balancer_sign: x grad = ", x.grad) + +def _test_deriv_balancer_magnitude(): + channel_dim = 0 + magnitudes = torch.arange(0, 1, 0.01) + N = 500 + x = 1.0 * (torch.randn(magnitudes.numel(), N) * magnitudes.unsqueeze(-1)) + x = x.detach() + x.requires_grad = True + m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + y_grad[-1,:] = 0 + + y = m(x) + y.backward(gradient=y_grad) + print("_test_deriv_balancer_magnitude: x = ", x) + print("_test_deriv_balancer_magnitude: y grad = ", y_grad) + print("_test_deriv_balancer_magnitude: x grad = ", x.grad) def _test_basic_norm(): @@ -543,7 +562,8 @@ def _test_basic_norm(): if __name__ == '__main__': - _test_deriv_balancer() + _test_deriv_balancer_sign() + _test_deriv_balancer_magnitude() _test_exp_scale_swish() _test_exp_scale_relu() _test_basic_norm() diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 36a1ae869..618d90490 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2z0.02", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.1", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From bfce5f63e498877f4d3c9681a0da341ce90d2e67 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 10 Mar 2022 23:49:09 +0800 Subject: [PATCH 046/234] Fix dirname --- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 618d90490..d75341a07 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.1", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From e3e14cf7a4850d2a7785a53a2bb7d47c53b44310 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 14:16:33 +0800 Subject: [PATCH 047/234] Change min-abs threshold from 0.2 to 0.5 --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index d1ff7f233..d7be46f17 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -316,7 +316,7 @@ class DerivBalancerFunction(torch.autograd.Function): channel_dim: int, threshold: float = 0.05, max_factor: float = 0.05, - min_abs: float = 0.2) -> Tensor: + min_abs: float = 0.5) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim @@ -452,7 +452,7 @@ class DerivBalancer(torch.nn.Module): def __init__(self, channel_dim: int, threshold: float = 0.05, max_factor: float = 0.02, - min_abs: float = 0.2): + min_abs: float = 0.5): super(DerivBalancer, self).__init__() self.channel_dim = channel_dim self.threshold = threshold diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index d75341a07..80febc677 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.2", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.5", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From ab9a17413ab966e013d275919791040c23002407 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 14:37:52 +0800 Subject: [PATCH 048/234] Scale up pos_bias_u and pos_bias_v before use. --- egs/librispeech/ASR/transducer_stateless/conformer.py | 4 +++- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 7a7a09c27..d0be5af00 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -614,7 +614,9 @@ class RelPositionMultiheadAttention(nn.Module): assert ( head_dim * num_heads == embed_dim ), "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + q = q * scaling if torch.equal(query, key) and torch.equal(key, value): # self-attention @@ -764,7 +766,7 @@ class RelPositionMultiheadAttention(nn.Module): attn_output_weights = ( matrix_ac + matrix_bd - ) * scaling # (batch, head, time1, time2) + ) # (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, -1 diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 80febc677..c9654cc94 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.5", + default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.5_pbs", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 137eae0b95ee5a0c4dcd137c3d1279301006c5ee Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 14:41:55 +0800 Subject: [PATCH 049/234] Reduce max_factor to 0.01 --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index d7be46f17..2e4eb754b 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -451,7 +451,7 @@ class DerivBalancer(torch.nn.Module): """ def __init__(self, channel_dim: int, threshold: float = 0.05, - max_factor: float = 0.02, + max_factor: float = 0.01, min_abs: float = 0.5): super(DerivBalancer, self).__init__() self.channel_dim = channel_dim From 2940d3106f08a06d37618229e872ace7e371fa66 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 14:43:57 +0800 Subject: [PATCH 050/234] Fix q*scaling logic --- egs/librispeech/ASR/transducer_stateless/conformer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index d0be5af00..e14c7a02e 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -616,7 +616,6 @@ class RelPositionMultiheadAttention(nn.Module): ), "embed_dim must be divisible by num_heads" scaling = float(head_dim) ** -0.5 - q = q * scaling if torch.equal(query, key) and torch.equal(key, value): # self-attention @@ -721,7 +720,7 @@ class RelPositionMultiheadAttention(nn.Module): ) key_padding_mask = key_padding_mask.to(torch.bool) - q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) + q = (q.contiguous() * scaling).view(tgt_len, bsz, num_heads, head_dim) k = k.contiguous().view(-1, bsz, num_heads, head_dim) v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) From bcf417fce2b3115e7dda8d3b0e0a6cbafd5d71ac Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 14:47:46 +0800 Subject: [PATCH 051/234] Change max_factor in DerivBalancer from 0.025 to 0.01; fix scaling code. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/conformer.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 2e4eb754b..ce25ad8ea 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -48,13 +48,13 @@ class Conv2dSubsampling(nn.Module): in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), DerivBalancer(channel_dim=1, threshold=0.05, - max_factor=0.025), + max_factor=0.01), ExpScaleRelu(odim, 1, 1, speed=20.0), nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), DerivBalancer(channel_dim=1, threshold=0.05, - max_factor=0.025), + max_factor=0.01), ExpScaleRelu(odim, 1, 1, speed=20.0), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index e14c7a02e..051512969 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -159,7 +159,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, - max_factor=0.025), + max_factor=0.01), SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), @@ -168,7 +168,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward_macaron = nn.Sequential( nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, - max_factor=0.025), + max_factor=0.01), SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), @@ -720,7 +720,7 @@ class RelPositionMultiheadAttention(nn.Module): ) key_padding_mask = key_padding_mask.to(torch.bool) - q = (q.contiguous() * scaling).view(tgt_len, bsz, num_heads, head_dim) + q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim) k = k.contiguous().view(-1, bsz, num_heads, head_dim) v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) From bec33e6855afba8c4f2739cbcf6a7a67398ec210 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 16:37:17 +0800 Subject: [PATCH 052/234] init 1st conv module to smaller variance --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 8 ++++++++ egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index ce25ad8ea..6a697aa0e 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -59,6 +59,14 @@ class Conv2dSubsampling(nn.Module): ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = BasicNorm(odim) + self._reset_parameters() + + def _reset_parameters(self): + # init weights with smaller than default variance, because otherwise + # they learn too slowly in relative terms (assuming we're training with adam). + nn.init.normal_(self.conv[0].weight, std=0.05) + nn.init.constant_(self.conv[0].bias, 0.0) + def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index c9654cc94..5d6d72490 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline_randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.5_pbs", + default="transducer_stateless/randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.5_pbs_cinit", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 5eafccb36942b3da42024c48b1af237ef1f613ec Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 17:46:33 +0800 Subject: [PATCH 053/234] Change how scales are applied; fix residual bug --- .../ASR/transducer_stateless/conformer.py | 17 +++++++++++++---- .../ASR/transducer_stateless/train.py | 2 +- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 051512969..2c602bbea 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -229,10 +229,17 @@ class ConformerEncoderLayer(nn.Module): attn_mask=src_mask, key_padding_mask=src_key_padding_mask, )[0] - src = residual + self.dropout(src_att) + # natural rms scale of mha output is about 2 to 6. scaling down by 0.1 takes it + # to 0.2 to 0.6, which is suitable to add to the inputs assuming the output + # of the previous convolution layer had a magnitude of around 1.0 + # (this magnitude of 1.0, or a bit less, like 0.3, is learned but is + # dictated by considerations of what is done to the output of the + # encoder. + post_scale_mha = 0.1 + src = residual + post_scale_mha * self.dropout(src_att) # convolution module - src = residual + self.dropout(self.conv_module(self.scale_conv(src))) + src = src + self.dropout(self.conv_module(self.scale_conv(src))) # feed forward module src = src + self.dropout(self.feed_forward(self.scale_ff(src))) @@ -891,13 +898,15 @@ class ConvolutionModule(nn.Module): # 1D Depthwise Conv x = self.depthwise_conv(x) + + # TODO: can have a learned scale in here, or a fixed one. + x = self.activation(x) + # x is (batch, channels, time) x = x.permute(0, 2, 1) x = self.scale(x) x = x.permute(0, 2, 1) - x = self.activation(x) - x = self.pointwise_conv2(x) # (batch, channel, time) return x.permute(2, 0, 1) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 5d6d72490..b5e9e846f 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_brelu2swish2_0.1_bnorm2ma0.5_pbs_cinit", + default="transducer_stateless/randcombine1_expscale3_rework", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From a0d5e2932ccfd2b2eadb271434dc30a14c980c7c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 18:17:49 +0800 Subject: [PATCH 054/234] Reduce min_abs from 0.5 to 0.2 --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 6a697aa0e..6b1cb128f 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -460,7 +460,7 @@ class DerivBalancer(torch.nn.Module): def __init__(self, channel_dim: int, threshold: float = 0.05, max_factor: float = 0.01, - min_abs: float = 0.5): + min_abs: float = 0.2): super(DerivBalancer, self).__init__() self.channel_dim = channel_dim self.threshold = threshold From 98156711efb11e92d8b50eb426041b62da4a5564 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 19:05:55 +0800 Subject: [PATCH 055/234] Introduce in_scale=0.5 for SwishExpScale --- .../ASR/conformer_ctc/subsampling.py | 19 ++++++++++++------- .../ASR/transducer_stateless/conformer.py | 4 ++-- .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 6b1cb128f..52a58d104 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -221,18 +221,21 @@ class ExpScale(torch.nn.Module): -def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor: +def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float, in_scale: float) -> Tensor: # double-swish, implemented/approximated as offset-swish + if in_scale != 1.0: + x = x * in_scale x = (x * torch.sigmoid(x - 1.0)) x = x * (scale * speed).exp() return x class SwishExpScaleFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: + def forward(ctx, x: Tensor, scale: Tensor, speed: float, in_scale: float) -> Tensor: ctx.save_for_backward(x.detach(), scale.detach()) ctx.speed = speed - return _exp_scale_swish(x, scale, speed) + ctx.in_scale = in_scale + return _exp_scale_swish(x, scale, speed, in_scale) @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: @@ -240,21 +243,23 @@ class SwishExpScaleFunction(torch.autograd.Function): x.requires_grad = True scale.requires_grad = True with torch.enable_grad(): - y = _exp_scale_swish(x, scale, ctx.speed) + y = _exp_scale_swish(x, scale, ctx.speed, ctx.in_scale) y.backward(gradient=y_grad) - return x.grad, scale.grad, None + return x.grad, scale.grad, None, None class SwishExpScale(torch.nn.Module): # combines ExpScale and a Swish (actually the ExpScale is after the Swish). # caution: need to specify name for speed, e.g. SwishExpScale(50, speed=4.0) - def __init__(self, *shape, speed: float = 1.0): + # + def __init__(self, *shape, speed: float = 1.0, in_scale: float = 1.0): super(SwishExpScale, self).__init__() + self.in_scale = in_scale self.scale = nn.Parameter(torch.zeros(*shape)) self.speed = speed def forward(self, x: Tensor) -> Tensor: - return SwishExpScaleFunction.apply(x, self.scale, self.speed) + return SwishExpScaleFunction.apply(x, self.scale, self.speed, self.in_scale) # x = (x * torch.sigmoid(x)) # x = (x * torch.sigmoid(x)) # x = x * (self.scale * self.speed).exp() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 2c602bbea..7b9aff71f 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -160,7 +160,7 @@ class ConformerEncoderLayer(nn.Module): nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.01), - SwishExpScale(dim_feedforward, speed=20.0), + SwishExpScale(dim_feedforward, speed=20.0, in_scale=0.5), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -169,7 +169,7 @@ class ConformerEncoderLayer(nn.Module): nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.01), - SwishExpScale(dim_feedforward, speed=20.0), + SwishExpScale(dim_feedforward, speed=20.0, in_scale=0.5), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index b5e9e846f..190406491 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework", + default="transducer_stateless/randcombine1_expscale3_rework_0.5", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From cc558faf262f7db5bfbc637e86a7102f23c1f77e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 19:11:50 +0800 Subject: [PATCH 056/234] Fix scale from 0.5 to 2.0 as I really intended.. --- egs/librispeech/ASR/transducer_stateless/conformer.py | 4 ++-- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 7b9aff71f..fa25e6ca0 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -160,7 +160,7 @@ class ConformerEncoderLayer(nn.Module): nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.01), - SwishExpScale(dim_feedforward, speed=20.0, in_scale=0.5), + SwishExpScale(dim_feedforward, speed=20.0, in_scale=2.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) @@ -169,7 +169,7 @@ class ConformerEncoderLayer(nn.Module): nn.Linear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.01), - SwishExpScale(dim_feedforward, speed=20.0, in_scale=0.5), + SwishExpScale(dim_feedforward, speed=20.0, in_scale=2.0), nn.Dropout(dropout), nn.Linear(dim_feedforward, d_model), ) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 190406491..c72a9dd28 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework_0.5", + default="transducer_stateless/randcombine1_expscale3_rework_2.0", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 2d3a76292d0649a358f39835cd5944c0ac406b37 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 20:12:45 +0800 Subject: [PATCH 057/234] Set scaling on SwishExpScale --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 52a58d104..caac230ed 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -255,7 +255,9 @@ class SwishExpScale(torch.nn.Module): def __init__(self, *shape, speed: float = 1.0, in_scale: float = 1.0): super(SwishExpScale, self).__init__() self.in_scale = in_scale - self.scale = nn.Parameter(torch.zeros(*shape)) + initial_log_scale = torch.tensor(1.0 / in_scale).log() / speed + initial_log_scale = (torch.ones(*shape) * initial_log_scale).detach() + self.scale = nn.Parameter(initial_log_scale) self.speed = speed def forward(self, x: Tensor) -> Tensor: From 7eb5a84cbeb4242736b28d1d1ea5a118cb1cc256 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 11 Mar 2022 21:00:43 +0800 Subject: [PATCH 058/234] Add identity pre_norm_final for diagnostics. --- egs/librispeech/ASR/transducer_stateless/conformer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index fa25e6ca0..389a7cb7f 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -181,6 +181,7 @@ class ConformerEncoderLayer(nn.Module): self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5) self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5) + self.pre_norm_final = Identity() self.norm_final = BasicNorm(d_model) self.dropout = nn.Dropout(dropout) @@ -244,7 +245,7 @@ class ConformerEncoderLayer(nn.Module): # feed forward module src = src + self.dropout(self.feed_forward(self.scale_ff(src))) - src = self.norm_final(src) + src = self.norm_final(self.pre_norm_final(src)) return src @@ -930,8 +931,9 @@ class SwishOffset(torch.nn.Module): return x * torch.sigmoid(x + self.offset) -def identity(x): - return x +class Identity(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + return x class RandomCombine(torch.nn.Module): From 76a2b9d36239566aae2125837f653ecbeb3a1ca9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 11:19:49 +0800 Subject: [PATCH 059/234] Add learnable post-scale for mha --- egs/librispeech/ASR/transducer_stateless/conformer.py | 10 ++-------- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 389a7cb7f..963cb2cd9 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -177,6 +177,7 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) self.scale_mha = ExpScale(1, speed=10.0, initial_scale=0.2) + self.post_scale_mha = ExpScale(1, speed=10.0, initial_scale=1.0) self.scale_conv = ExpScale(1, speed=10.0, initial_scale=0.5) self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5) self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5) @@ -230,14 +231,7 @@ class ConformerEncoderLayer(nn.Module): attn_mask=src_mask, key_padding_mask=src_key_padding_mask, )[0] - # natural rms scale of mha output is about 2 to 6. scaling down by 0.1 takes it - # to 0.2 to 0.6, which is suitable to add to the inputs assuming the output - # of the previous convolution layer had a magnitude of around 1.0 - # (this magnitude of 1.0, or a bit less, like 0.3, is learned but is - # dictated by considerations of what is done to the output of the - # encoder. - post_scale_mha = 0.1 - src = residual + post_scale_mha * self.dropout(src_att) + src = residual + post_scale_mha(self.dropout(src_att)) # convolution module src = src + self.dropout(self.conv_module(self.scale_conv(src))) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index c72a9dd28..be771b517 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework_2.0", + default="transducer_stateless/randcombine1_expscale3_rework_2.0_b", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 0abba9e7a2eb849164459ee5ec22d7b2da28d9c5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 11:20:44 +0800 Subject: [PATCH 060/234] Fix self.post-scale-mha --- egs/librispeech/ASR/transducer_stateless/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 963cb2cd9..3f9becded 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -231,7 +231,7 @@ class ConformerEncoderLayer(nn.Module): attn_mask=src_mask, key_padding_mask=src_key_padding_mask, )[0] - src = residual + post_scale_mha(self.dropout(src_att)) + src = residual + self.post_scale_mha(self.dropout(src_att)) # convolution module src = src + self.dropout(self.conv_module(self.scale_conv(src))) From ca8cf2a73b4d65406d9ca5b4648af4768926d3b3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 15:38:13 +0800 Subject: [PATCH 061/234] Another rework, use scales on linear/conv --- .../ASR/conformer_ctc/subsampling.py | 156 ++++++++++++------ .../ASR/transducer_stateless/conformer.py | 73 ++++---- 2 files changed, 140 insertions(+), 89 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index caac230ed..831537d79 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -44,20 +44,20 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d( + ScaledConv2d( in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), DerivBalancer(channel_dim=1, threshold=0.05, max_factor=0.01), ExpScaleRelu(odim, 1, 1, speed=20.0), - nn.Conv2d( + ScaledConv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), DerivBalancer(channel_dim=1, threshold=0.05, max_factor=0.01), ExpScaleRelu(odim, 1, 1, speed=20.0), ) - self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) + self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = BasicNorm(odim) self._reset_parameters() @@ -221,21 +221,18 @@ class ExpScale(torch.nn.Module): -def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float, in_scale: float) -> Tensor: +def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor: # double-swish, implemented/approximated as offset-swish - if in_scale != 1.0: - x = x * in_scale x = (x * torch.sigmoid(x - 1.0)) x = x * (scale * speed).exp() return x class SwishExpScaleFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, scale: Tensor, speed: float, in_scale: float) -> Tensor: + def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: ctx.save_for_backward(x.detach(), scale.detach()) ctx.speed = speed - ctx.in_scale = in_scale - return _exp_scale_swish(x, scale, speed, in_scale) + return _exp_scale_swish(x, scale, speed) @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: @@ -243,25 +240,24 @@ class SwishExpScaleFunction(torch.autograd.Function): x.requires_grad = True scale.requires_grad = True with torch.enable_grad(): - y = _exp_scale_swish(x, scale, ctx.speed, ctx.in_scale) + y = _exp_scale_swish(x, scale, ctx.speed) y.backward(gradient=y_grad) - return x.grad, scale.grad, None, None + return x.grad, scale.grad, None class SwishExpScale(torch.nn.Module): # combines ExpScale and a Swish (actually the ExpScale is after the Swish). # caution: need to specify name for speed, e.g. SwishExpScale(50, speed=4.0) # - def __init__(self, *shape, speed: float = 1.0, in_scale: float = 1.0): + def __init__(self, *shape, speed: float = 1.0): super(SwishExpScale, self).__init__() - self.in_scale = in_scale - initial_log_scale = torch.tensor(1.0 / in_scale).log() / speed - initial_log_scale = (torch.ones(*shape) * initial_log_scale).detach() + + initial_log_scale = torch.zeros(()).detach() self.scale = nn.Parameter(initial_log_scale) self.speed = speed def forward(self, x: Tensor) -> Tensor: - return SwishExpScaleFunction.apply(x, self.scale, self.speed, self.in_scale) + return SwishExpScaleFunction.apply(x, self.scale, self.speed) # x = (x * torch.sigmoid(x)) # x = (x * torch.sigmoid(x)) # x = x * (self.scale * self.speed).exp() @@ -383,12 +379,11 @@ class BasicNorm(torch.nn.Module): interprted as an offset from the input's ndim if negative. shis is NOT the num_channels; it should typically be one of {-2, -1, 0, 1, 2, 3}. - initial_eps_scale: a constant that determines the initial - "epsilon" that we add as ballast in: - scale = output_scale * ((input_vec**2).sum() + epsilon)**-0.5 - Note: our epsilon is actually large, not small, but we keep the name - to indicate the connection with normal LayerNorm. We set - epsilon initially to num_channels * initial_eps_scale. + initial_eps: the initial "epsilon" that we add as ballast in: + scale = ((input_vec**2).mean() + epsilon)**-0.5 + Note: our epsilon is actually large, but we keep the name + to indicate the connection with normal LayerNorm. + speed: a scaling factor that can be interpreted as scaling the learning rate for this module. CAUTION: the default value of 10.0 intended to be used with Adam or amsgrad-type optimizers, e.g. Adam or Noam. @@ -398,42 +393,101 @@ class BasicNorm(torch.nn.Module): def __init__(self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. - initial_eps_scale: float = 0.25, - speed: float = 10.0): + eps: float = 0.25): super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim - self.speed = speed - eps = num_channels * initial_eps_scale - # log_eps = log(eps) / speed - log_eps = torch.tensor(eps).log() / speed - self.log_eps = nn.Parameter(log_eps.detach()) - # initial output-scale, to get LayerNorm-like behavior, is - # sqrt(num_channels). - initial_scale = torch.tensor(num_channels ** 0.5).log() / speed - self.log_scale = nn.Parameter(initial_scale.detach()) - - def _inner(self, x: Tensor) -> Tensor: - # inner product on last dim of x, keeping the dimension, - # i.e. torch.sum(x**2, dim=-1, keepdim=True), but more - # efficient. - if hasattr(torch, 'inner'): - return torch.inner(x).unsqueeze(-1) - else: - # TODO: we can do this with matrix multiplication, maybe.a - return torch.sum(x**2, dim=-1, keepdim=True) + self.eps = eps def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - x = x.transpose(-1, self.channel_dim) - eps = (self.log_eps * self.speed).exp() - out_scale = (self.log_scale * self.speed).exp() + scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps) ** -0.5 + return x * scales + + +class ScaledLinear(nn.Linear): + def __init__(self, *args, scale_speed=5.0, **kwargs): + super(ScaledLinear, self).__init__(*args, **kwargs) + self.weight_scale = nn.Parameter(torch.zeros(())) + self.scale_speed = scale_speed + if self.bias is not None: + self.bias_scale = nn.Parameter(torch.zeros(())) + else: + self.register_parameter('bias_scale', None) + + + def get_weight(self): + return self.weight * (self.weight_scale * self.scale_speed).exp() + + def get_bias(self): + return (None if self.bias is None else + self.bias * (self.bias_scale * self.scale_speed).exp()) + + + def forward(self, input: Tensor) -> Tensor: + return torch.nn.functional.linear(input, self.get_weight(), + self.get_bias()) + + +class ScaledConv1d(nn.Conv1d): + def __init__(self, *args, scale_speed = 5.0, **kwargs): + super(ScaledConv1d, self).__init__(*args, **kwargs) + self.scale_speed = scale_speed + self.weight_scale = nn.Parameter(torch.zeros(())) + if self.bias is not None: + self.bias_scale = nn.Parameter(torch.zeros(())) + else: + self.register_parameter('bias_scale', None) + + def get_weight(self): + return self.weight * (self.weight_scale * self.scale_speed).exp() + + def get_bias(self): + return (None if self.bias is None else + self.bias * (self.bias_scale * self.scale_speed).exp()) + + def forward(self, input: Tensor) -> Tensor: + F = torch.nn.functional + if self.padding_mode != 'zeros': + return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + self.get_weight(), self.get_bias(), self.stride, + _single(0), self.dilation, self.groups) + return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, + self.padding, self.dilation, self.groups) + + + +class ScaledConv2d(nn.Conv2d): + def __init__(self, *args, scale_speed=5.0, **kwargs): + super(ScaledConv2d, self).__init__(*args, **kwargs) + self.scale_speed = scale_speed + self.weight_scale = nn.Parameter(torch.zeros(())) + if self.bias is not None: + self.bias_scale = nn.Parameter(torch.zeros(())) + else: + self.register_parameter('bias_scale', None) + + + def get_weight(self): + return self.weight * (self.weight_scale * self.scale_speed).exp() + + def get_bias(self): + return (None if self.bias is None else + self.bias * (self.bias_scale * self.scale_speed).exp()) + + def _conv_forward(self, input, weight): + F = torch.nn.functional + if self.padding_mode != 'zeros': + return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + weight, self.get_bias(), self.stride, + _pair(0), self.dilation, self.groups) + return F.conv2d(input, weight, self.bias, self.stride, + self.padding, self.dilation, self.groups) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.get_weight()) - scales = out_scale * (self._inner(x) + eps) ** -0.5 - x = x * scales - x = x.transpose(-1, self.channel_dim) - return x @@ -576,6 +630,8 @@ def _test_basic_norm(): + + if __name__ == '__main__': _test_deriv_balancer_sign() _test_deriv_balancer_magnitude() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 3f9becded..93f7dd170 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import copy import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import PeLU, ExpScale, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm +from subsampling import PeLU, ExpScale, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d import torch from torch import Tensor, nn @@ -157,30 +157,25 @@ class ConformerEncoderLayer(nn.Module): ) self.feed_forward = nn.Sequential( - nn.Linear(d_model, dim_feedforward), + ScaledLinear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.01), - SwishExpScale(dim_feedforward, speed=20.0, in_scale=2.0), + SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), + ScaledLinear(dim_feedforward, d_model), ) self.feed_forward_macaron = nn.Sequential( - nn.Linear(d_model, dim_feedforward), + ScaledLinear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1, threshold=0.05, max_factor=0.01), - SwishExpScale(dim_feedforward, speed=20.0, in_scale=2.0), + SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), + ScaledLinear(dim_feedforward, d_model), ) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.scale_mha = ExpScale(1, speed=10.0, initial_scale=0.2) - self.post_scale_mha = ExpScale(1, speed=10.0, initial_scale=1.0) - self.scale_conv = ExpScale(1, speed=10.0, initial_scale=0.5) - self.scale_ff = ExpScale(1, speed=10.0, initial_scale=0.5) - self.scale_ff_macaron = ExpScale(1, speed=10.0, initial_scale=0.5) self.pre_norm_final = Identity() self.norm_final = BasicNorm(d_model) @@ -216,13 +211,10 @@ class ConformerEncoderLayer(nn.Module): residual = src - src = src + self.dropout(self.feed_forward_macaron( - self.scale_ff_macaron(src))) + src = src + self.dropout(self.feed_forward_macaron(src)) # multi-headed self-attention module - residual = src - src = self.scale_mha(src) src_att = self.self_attn( src, src, @@ -231,13 +223,13 @@ class ConformerEncoderLayer(nn.Module): attn_mask=src_mask, key_padding_mask=src_key_padding_mask, )[0] - src = residual + self.post_scale_mha(self.dropout(src_att)) + src = src + self.dropout(src_att) # convolution module - src = src + self.dropout(self.conv_module(self.scale_conv(src))) + src = src + self.dropout(self.conv_module(src)) # feed forward module - src = src + self.dropout(self.feed_forward(self.scale_ff(src))) + src = src + self.dropout(self.feed_forward(src)) src = self.norm_final(self.pre_norm_final(src)) @@ -420,6 +412,7 @@ class RelPositionMultiheadAttention(nn.Module): embed_dim: int, num_heads: int, dropout: float = 0.0, + scale_speed: float = 5.0 ) -> None: super(RelPositionMultiheadAttention, self).__init__() self.embed_dim = embed_dim @@ -430,18 +423,27 @@ class RelPositionMultiheadAttention(nn.Module): self.head_dim * num_heads == self.embed_dim ), "embed_dim must be divisible by num_heads" - self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True) # linear transformation for positional encoding. - self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) + self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) # these two learnable bias are used in matrix c and matrix d # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.scale_speed = scale_speed + self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) + self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) self._reset_parameters() + def _pos_bias_u(self): + return self.pos_bias_u * (self.pos_bias_u_scale * self.scale_speed).exp() + + def _pos_bias_v(self): + return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp() + def _reset_parameters(self) -> None: nn.init.xavier_uniform_(self.in_proj.weight) nn.init.constant_(self.in_proj.bias, 0.0) @@ -508,11 +510,11 @@ class RelPositionMultiheadAttention(nn.Module): pos_emb, self.embed_dim, self.num_heads, - self.in_proj.weight, - self.in_proj.bias, + self.in_proj.get_weight(), + self.in_proj.get_bias(), self.dropout, - self.out_proj.weight, - self.out_proj.bias, + self.out_proj.get_weight(), + self.out_proj.get_bias(), training=self.training, key_padding_mask=key_padding_mask, need_weights=need_weights, @@ -743,11 +745,11 @@ class RelPositionMultiheadAttention(nn.Module): p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) - q_with_bias_u = (q + self.pos_bias_u).transpose( + q_with_bias_u = (q + self._pos_bias_u()).transpose( 1, 2 ) # (batch, head, time1, d_k) - q_with_bias_v = (q + self.pos_bias_v).transpose( + q_with_bias_v = (q + self._pos_bias_v()).transpose( 1, 2 ) # (batch, head, time1, d_k) @@ -842,7 +844,7 @@ class ConvolutionModule(nn.Module): # kernerl_size should be a odd number for 'SAME' padding assert (kernel_size - 1) % 2 == 0 - self.pointwise_conv1 = nn.Conv1d( + self.pointwise_conv1 = ScaledConv1d( channels, 2 * channels, kernel_size=1, @@ -850,7 +852,7 @@ class ConvolutionModule(nn.Module): padding=0, bias=bias, ) - self.depthwise_conv = nn.Conv1d( + self.depthwise_conv = ScaledConv1d( channels, channels, kernel_size, @@ -860,12 +862,10 @@ class ConvolutionModule(nn.Module): bias=bias, ) - self.scale = ExpScale(1, speed=10.0, initial_scale=1.0) - # shape: (channels, 1), broadcasts with (batch, channel, time). self.activation = SwishOffset() - self.pointwise_conv2 = nn.Conv1d( + self.pointwise_conv2 = ScaledConv1d( channels, channels, kernel_size=1, @@ -897,11 +897,6 @@ class ConvolutionModule(nn.Module): # TODO: can have a learned scale in here, or a fixed one. x = self.activation(x) - # x is (batch, channels, time) - x = x.permute(0, 2, 1) - x = self.scale(x) - x = x.permute(0, 2, 1) - x = self.pointwise_conv2(x) # (batch, channel, time) return x.permute(2, 0, 1) @@ -982,7 +977,7 @@ class RandomCombine(torch.nn.Module): assert pure_prob >= 0 and pure_prob <= 1 assert final_weight > 0 and final_weight < 1 assert num_inputs >= 1 - self.linear = nn.ModuleList([nn.Linear(num_channels, num_channels, bias=True) + self.linear = nn.ModuleList([ScaledLinear(num_channels, num_channels, bias=True) for _ in range(num_inputs - 1)]) self.num_inputs = num_inputs From d906bc2a4f14fd9394363e3ec6d473d9ed2aff3b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 15:38:39 +0800 Subject: [PATCH 062/234] Change dir name --- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index be771b517..1a57d654f 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework_2.0_b", + default="transducer_stateless/randcombine1_expscale3_rework2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From a392cb9fbc5bc23228ff142354c7962b59fdaa74 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 16:53:03 +0800 Subject: [PATCH 063/234] Reduce initial scaling of modules --- .../ASR/conformer_ctc/subsampling.py | 22 +++++++++++-------- .../ASR/transducer_stateless/conformer.py | 2 +- .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 831537d79..dab0e1e1d 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -407,12 +407,13 @@ class BasicNorm(torch.nn.Module): class ScaledLinear(nn.Linear): - def __init__(self, *args, scale_speed=5.0, **kwargs): + def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) - self.weight_scale = nn.Parameter(torch.zeros(())) + initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) self.scale_speed = scale_speed if self.bias is not None: - self.bias_scale = nn.Parameter(torch.zeros(())) + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: self.register_parameter('bias_scale', None) @@ -431,12 +432,14 @@ class ScaledLinear(nn.Linear): class ScaledConv1d(nn.Conv1d): - def __init__(self, *args, scale_speed = 5.0, **kwargs): + def __init__(self, *args, scale_speed = 5.0, + initial_scale=1.0, **kwargs): super(ScaledConv1d, self).__init__(*args, **kwargs) self.scale_speed = scale_speed - self.weight_scale = nn.Parameter(torch.zeros(())) + initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: - self.bias_scale = nn.Parameter(torch.zeros(())) + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: self.register_parameter('bias_scale', None) @@ -459,12 +462,13 @@ class ScaledConv1d(nn.Conv1d): class ScaledConv2d(nn.Conv2d): - def __init__(self, *args, scale_speed=5.0, **kwargs): + def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs): super(ScaledConv2d, self).__init__(*args, **kwargs) self.scale_speed = scale_speed - self.weight_scale = nn.Parameter(torch.zeros(())) + initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: - self.bias_scale = nn.Parameter(torch.zeros(())) + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: self.register_parameter('bias_scale', None) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 93f7dd170..aa35f5e7e 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -162,7 +162,7 @@ class ConformerEncoderLayer(nn.Module): max_factor=0.01), SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) self.feed_forward_macaron = nn.Sequential( diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 1a57d654f..b871efd13 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2", + default="transducer_stateless/randcombine1_expscale3_rework2b", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From a24572abd1285ff12c89e31908694689fa2e6d41 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 17:28:43 +0800 Subject: [PATCH 064/234] Bug-fix RE bias --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index dab0e1e1d..5f1e376a9 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -486,7 +486,7 @@ class ScaledConv2d(nn.Conv2d): return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), weight, self.get_bias(), self.stride, _pair(0), self.dilation, self.groups) - return F.conv2d(input, weight, self.bias, self.stride, + return F.conv2d(input, weight, self.get_bias(), self.stride, self.padding, self.dilation, self.groups) def forward(self, input: Tensor) -> Tensor: From b7b2d8970b608ff3954039e99dbbd95186b61bae Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 17:47:35 +0800 Subject: [PATCH 065/234] Cosmetic change --- egs/librispeech/ASR/transducer_stateless/conformer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index aa35f5e7e..a270cd8ae 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -171,7 +171,7 @@ class ConformerEncoderLayer(nn.Module): max_factor=0.01), SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) @@ -208,9 +208,6 @@ class ConformerEncoderLayer(nn.Module): """ # macaron style feed forward module - residual = src - - src = src + self.dropout(self.feed_forward_macaron(src)) @@ -872,6 +869,7 @@ class ConvolutionModule(nn.Module): stride=1, padding=0, bias=bias, + initial_scale=0.5 ) def forward(self, x: Tensor) -> Tensor: From db7a3b6eea34e532240dae3409c6d64e8eab9806 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 18:50:02 +0800 Subject: [PATCH 066/234] Reduce initial_scale. --- egs/librispeech/ASR/transducer_stateless/conformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index a270cd8ae..9dd6bae4d 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -421,7 +421,7 @@ class RelPositionMultiheadAttention(nn.Module): ), "embed_dim must be divisible by num_heads" self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True) + self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25) # linear transformation for positional encoding. self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) @@ -869,7 +869,7 @@ class ConvolutionModule(nn.Module): stride=1, padding=0, bias=bias, - initial_scale=0.5 + initial_scale=0.25 ) def forward(self, x: Tensor) -> Tensor: From be0a79cbcae9fb6a02f139ef4385af7fa6f80032 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 19:00:48 +0800 Subject: [PATCH 067/234] Replace ExpScaleRelu with DoubleSwish() --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 5f1e376a9..13259d166 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -49,13 +49,13 @@ class Conv2dSubsampling(nn.Module): ), DerivBalancer(channel_dim=1, threshold=0.05, max_factor=0.01), - ExpScaleRelu(odim, 1, 1, speed=20.0), + DoubleSwish(), ScaledConv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), DerivBalancer(channel_dim=1, threshold=0.05, max_factor=0.01), - ExpScaleRelu(odim, 1, 1, speed=20.0), + DoubleSwish(), ) self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = BasicNorm(odim) From 2117f46361c2b2deb63194de43098e1a17714d61 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 19:02:14 +0800 Subject: [PATCH 068/234] DoubleSwish fix --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 13259d166..6bf0aefe4 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -537,13 +537,13 @@ class DerivBalancer(torch.nn.Module): self.max_factor, self.min_abs) +class DoubleSwish(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x - 1.0) def _test_exp_scale_swish(): - class DoubleSwish(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" - return x * torch.sigmoid(x - 1.0) x1 = torch.randn(50, 60).detach() x2 = x1.detach() From 6042c96db2f68c24f08aadf93904d0383dcd7fc9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 12 Mar 2022 20:54:46 +0800 Subject: [PATCH 069/234] Use learnable scales for joiner and decoder --- .../ASR/transducer_stateless/decoder.py | 187 +++++++++++++++++- .../ASR/transducer_stateless/joiner.py | 4 +- .../ASR/transducer_stateless/train.py | 2 +- .../ASR/transducer_stateless/transformer.py | 4 +- 4 files changed, 190 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index 003b03a2e..bc4bcb3f6 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -17,6 +17,9 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch import Tensor +from typing import Optional +from subsampling import ScaledConv1d class Decoder(nn.Module): @@ -52,7 +55,7 @@ class Decoder(nn.Module): 1 means bigram; 2 means trigram. n means (n+1)-gram. """ super().__init__() - self.embedding = nn.Embedding( + self.embedding = ScaledEmbedding( num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=blank_id, @@ -62,7 +65,7 @@ class Decoder(nn.Module): assert context_size >= 1, context_size self.context_size = context_size if context_size > 1: - self.conv = nn.Conv1d( + self.conv = ScaledConv1d( in_channels=embedding_dim, out_channels=embedding_dim, kernel_size=context_size, @@ -97,3 +100,183 @@ class Decoder(nn.Module): embedding_out = self.conv(embedding_out) embedding_out = embedding_out.permute(0, 2, 1) return embedding_out + + + +class ScaledEmbedding(nn.Module): + r"""A simple lookup table that stores embeddings of a fixed dictionary and size. + + This module is often used to store word embeddings and retrieve them using indices. + The input to the module is a list of indices, and the output is the corresponding + word embeddings. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` + (initialized to zeros) whenever it encounters the index. + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. + scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. + See Notes for more details regarding sparse gradients. + + Attributes: + weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) + initialized from :math:`\mathcal{N}(0, 1)` + + Shape: + - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract + - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` + + .. note:: + Keep in mind that only a limited number of optimizers support + sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), + :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) + + .. note:: + With :attr:`padding_idx` set, the embedding vector at + :attr:`padding_idx` is initialized to all zeros. However, note that this + vector can be modified afterwards, e.g., using a customized + initialization method, and thus changing the vector used to pad the + output. The gradient for this vector from :class:`~torch.nn.Embedding` + is always zero. + + Examples:: + + >>> # an Embedding module containing 10 tensors of size 3 + >>> embedding = nn.Embedding(10, 3) + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) + >>> embedding(input) + tensor([[[-0.0251, -1.6902, 0.7172], + [-0.6431, 0.0748, 0.6969], + [ 1.4970, 1.3448, -0.9685], + [-0.3677, -2.7265, -0.1685]], + + [[ 1.4970, 1.3448, -0.9685], + [ 0.4362, -0.4004, 0.9400], + [-0.6431, 0.0748, 0.6969], + [ 0.9124, -2.3616, 1.1151]]]) + + + >>> # example with padding_idx + >>> embedding = nn.Embedding(10, 3, padding_idx=0) + >>> input = torch.LongTensor([[0,2,0,5]]) + >>> embedding(input) + tensor([[[ 0.0000, 0.0000, 0.0000], + [ 0.1535, -2.0309, 0.9315], + [ 0.0000, 0.0000, 0.0000], + [-0.1655, 0.9897, 0.0635]]]) + """ + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', + 'scale_grad_by_freq', 'sparse'] + + num_embeddings: int + embedding_dim: int + padding_idx: int + scale_grad_by_freq: bool + weight: Tensor + sparse: bool + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, _weight: Optional[Tensor] = None, + scale_speed: float = 5.0) -> None: + super(ScaledEmbedding, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.scale_grad_by_freq = scale_grad_by_freq + + self.scale_speed = scale_speed + self.scale = nn.Parameter(torch.tensor(embedding_dim**0.5).log() / scale_speed) + + if _weight is None: + self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) + self.reset_parameters() + else: + assert list(_weight.shape) == [num_embeddings, embedding_dim], \ + 'Shape of weight does not match num_embeddings and embedding_dim' + self.weight = nn.Parameter(_weight) + self.sparse = sparse + + def reset_parameters(self) -> None: + nn.init.normal_(self.weight, std=self.embedding_dim**-0.5) + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + scale = (self.scale * self.scale_speed).exp() + if input.numel() < self.num_embeddings: + return F.embedding( + input, self.weight, self.padding_idx, + None, 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, self.sparse) * scale + else: + return F.embedding( + input, self.weight * scale, self.padding_idx, + None, 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, self.sparse) + + + def extra_repr(self) -> str: + s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' + if self.padding_idx is not None: + s += ', padding_idx={padding_idx}' + if self.scale_grad_by_freq is not False: + s += ', scale_grad_by_freq={scale_grad_by_freq}' + if self.sparse is not False: + s += ', sparse=True' + return s.format(**self.__dict__) + + @classmethod + def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, + max_norm=None, norm_type=2., scale_grad_by_freq=False, + sparse=False): + r"""Creates Embedding instance from given 2-dimensional FloatTensor. + + Args: + embeddings (Tensor): FloatTensor containing weights for the Embedding. + First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``. + freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process. + Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True`` + padding_idx (int, optional): See module initialization documentation. + max_norm (float, optional): See module initialization documentation. + norm_type (float, optional): See module initialization documentation. Default ``2``. + scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``. + sparse (bool, optional): See module initialization documentation. + + Examples:: + + >>> # FloatTensor containing pretrained weights + >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]]) + >>> embedding = nn.Embedding.from_pretrained(weight) + >>> # Get embeddings for index 1 + >>> input = torch.LongTensor([1]) + >>> embedding(input) + tensor([[ 4.0000, 5.1000, 6.3000]]) + """ + assert embeddings.dim() == 2, \ + 'Embeddings parameter is expected to be 2-dimensional' + rows, cols = embeddings.shape + embedding = cls( + num_embeddings=rows, + embedding_dim=cols, + _weight=embeddings, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + embedding.weight.requires_grad = not freeze + return embedding diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py index 9fd9da4f1..8311461d3 100644 --- a/egs/librispeech/ASR/transducer_stateless/joiner.py +++ b/egs/librispeech/ASR/transducer_stateless/joiner.py @@ -16,7 +16,7 @@ import torch import torch.nn as nn - +from subsampling import ScaledLinear class Joiner(nn.Module): def __init__(self, input_dim: int, output_dim: int): @@ -24,7 +24,7 @@ class Joiner(nn.Module): self.input_dim = input_dim self.output_dim = output_dim - self.output_linear = nn.Linear(input_dim, output_dim) + self.output_linear = ScaledLinear(input_dim, output_dim) def forward( self, diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index b871efd13..c2202fe1e 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2b", + default="transducer_stateless/randcombine1_expscale3_rework2c", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved diff --git a/egs/librispeech/ASR/transducer_stateless/transformer.py b/egs/librispeech/ASR/transducer_stateless/transformer.py index e851dcc32..3fa847f4f 100644 --- a/egs/librispeech/ASR/transducer_stateless/transformer.py +++ b/egs/librispeech/ASR/transducer_stateless/transformer.py @@ -21,7 +21,7 @@ from typing import Optional, Tuple import torch import torch.nn as nn from encoder_interface import EncoderInterface -from subsampling import Conv2dSubsampling, VggSubsampling +from subsampling import Conv2dSubsampling, VggSubsampling, ScaledLinear from icefall.utils import make_pad_mask @@ -106,7 +106,7 @@ class Transformer(EncoderInterface): # TODO(fangjun): remove dropout self.encoder_output_layer = nn.Sequential( - nn.Dropout(p=dropout), nn.Linear(d_model, output_dim) + nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim) ) def forward( From e6a501d3c87222292eb83f0a2a158835e85606ba Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 13 Mar 2022 11:52:13 +0800 Subject: [PATCH 070/234] Add max-abs-value constraint in DerivBalancer --- .../ASR/conformer_ctc/subsampling.py | 42 +++++++++++++------ .../ASR/transducer_stateless/train.py | 2 +- 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 6bf0aefe4..ea0204138 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -325,9 +325,11 @@ class DerivBalancerFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, channel_dim: int, - threshold: float = 0.05, - max_factor: float = 0.05, - min_abs: float = 0.5) -> Tensor: + threshold: float, # e.g. 0.05 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 1000.0 + ) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim @@ -336,23 +338,26 @@ class DerivBalancerFunction(torch.autograd.Function): proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) factor = (threshold - proportion_positive).relu() * (max_factor / threshold) - below_threshold = (torch.mean(x.abs(), dim=sum_dims, keepdim=True) < min_abs) + mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) + below_threshold = (mean_abs < min_abs) + above_threshold = (mean_abs > max_abs) - ctx.save_for_backward(factor, xgt0, below_threshold) + ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: - factor, xgt0, below_threshold = ctx.saved_tensors + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: + factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors dtype = x_grad.dtype - too_small_factor = below_threshold.to(dtype) * (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0) + scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * + (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) - neg_delta_grad = x_grad.abs() * (factor + too_small_factor) + neg_delta_grad = x_grad.abs() * (factor + scale_factor) - return x_grad - neg_delta_grad, None, None, None, None + return x_grad - neg_delta_grad, None, None, None, None, None class BasicNorm(torch.nn.Module): @@ -521,20 +526,33 @@ class DerivBalancer(torch.nn.Module): than zero is that it will tend to prevent situations where the inputs shrink close to zero and the nonlinearity (e.g. swish) behaves like a linear function and we learn nothing. + min_abs: the minimum average-absolute-value per channel, which + we allow, before we start to modify the derivatives to prevent + this. This is to prevent a failure mode where the activations + become so small that the nonlinearity effectively becomes linear, + which makes the module useless and it gets even smaller + to try to "turn it off" completely. + max_abs: the maximum average-absolute-value per channel, which + we allow, before we start to modify the derivatives to prevent + this. This is to prevent the possibility of activations getting + out of floating point numerical range (especially in half precision). """ def __init__(self, channel_dim: int, threshold: float = 0.05, max_factor: float = 0.01, - min_abs: float = 0.2): + min_abs: float = 0.2, + max_abs: float = 1000.0): super(DerivBalancer, self).__init__() self.channel_dim = channel_dim self.threshold = threshold self.max_factor = max_factor self.min_abs = min_abs + self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold, - self.max_factor, self.min_abs) + self.max_factor, self.min_abs, + self.max_abs) class DoubleSwish(torch.nn.Module): diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index c2202fe1e..1434d6da4 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 5d69acb25b45d80e554c55d8dbc0aacc3432217a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 13 Mar 2022 13:15:20 +0800 Subject: [PATCH 071/234] Add max-abs-value --- .../ASR/conformer_ctc/subsampling.py | 52 +++++++++++-------- .../ASR/transducer_stateless/conformer.py | 6 +-- 2 files changed, 33 insertions(+), 25 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index ea0204138..8d01d8fc0 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -47,14 +47,12 @@ class Conv2dSubsampling(nn.Module): ScaledConv2d( in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), - DerivBalancer(channel_dim=1, threshold=0.05, - max_factor=0.01), + DerivBalancer(channel_dim=1), DoubleSwish(), ScaledConv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), - DerivBalancer(channel_dim=1, threshold=0.05, - max_factor=0.01), + DerivBalancer(channel_dim=1), DoubleSwish(), ) self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -325,7 +323,8 @@ class DerivBalancerFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, channel_dim: int, - threshold: float, # e.g. 0.05 + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 max_factor: float, # e.g. 0.01 min_abs: float, # e.g. 0.2 max_abs: float, # e.g. 1000.0 @@ -336,7 +335,13 @@ class DerivBalancerFunction(torch.autograd.Function): sum_dims = [d for d in range(x.ndim) if d != channel_dim] xgt0 = x > 0 proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) - factor = (threshold - proportion_positive).relu() * (max_factor / threshold) + factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) + if min_positive != 0.0 else 0.0) + factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 else 0.0) + factor = factor1 + factor2 + if isinstance(factor, float): + factor = torch.zeros_like(proportion_positive) mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) below_threshold = (mean_abs < min_abs) @@ -348,16 +353,14 @@ class DerivBalancerFunction(torch.autograd.Function): return x @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors dtype = x_grad.dtype scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) neg_delta_grad = x_grad.abs() * (factor + scale_factor) - - - return x_grad - neg_delta_grad, None, None, None, None, None + return x_grad - neg_delta_grad, None, None, None, None, None, None class BasicNorm(torch.nn.Module): @@ -516,7 +519,9 @@ class DerivBalancer(torch.nn.Module): Args: channel_dim: the dimension/axi corresponding to the channel, e.g. -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - threshold: the threshold, per channel, of the proportion of the time + min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time that (x > 0), below which we start to modify the derivatives. max_factor: the maximum factor by which we modify the derivatives, e.g. with max_factor=0.02, the the derivatives would be multiplied by @@ -538,19 +543,22 @@ class DerivBalancer(torch.nn.Module): out of floating point numerical range (especially in half precision). """ def __init__(self, channel_dim: int, - threshold: float = 0.05, + min_positive: float = 0.05, + max_positive: float = 0.95, max_factor: float = 0.01, min_abs: float = 0.2, max_abs: float = 1000.0): super(DerivBalancer, self).__init__() self.channel_dim = channel_dim - self.threshold = threshold + self.min_positive = min_positive + self.max_positive = max_positive self.max_factor = max_factor self.min_abs = min_abs self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: - return DerivBalancerFunction.apply(x, self.channel_dim, self.threshold, + return DerivBalancerFunction.apply(x, self.channel_dim, + self.min_positive, self.max_positive, self.max_factor, self.min_abs, self.max_abs) @@ -600,14 +608,14 @@ def _test_exp_scale_relu(): def _test_deriv_balancer_sign(): channel_dim = 0 probs = torch.arange(0, 1, 0.01) - N = 500 + N = 1000 x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = x.detach() x.requires_grad = True - m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2) + m = DerivBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, + max_factor=0.2, min_abs=0.0) y_grad = torch.sign(torch.randn(probs.numel(), N)) - y_grad[-1,:] = 0 y = m(x) y.backward(gradient=y_grad) @@ -618,14 +626,16 @@ def _test_deriv_balancer_sign(): def _test_deriv_balancer_magnitude(): channel_dim = 0 magnitudes = torch.arange(0, 1, 0.01) - N = 500 - x = 1.0 * (torch.randn(magnitudes.numel(), N) * magnitudes.unsqueeze(-1)) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True - m = DerivBalancer(channel_dim=0, threshold=0.05, max_factor=0.2, min_abs=0.2) + m = DerivBalancer(channel_dim=0, + min_positive=0.0, max_positive=1.0, + max_factor=0.2, + min_abs=0.2, max_abs=0.8) y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) - y_grad[-1,:] = 0 y = m(x) y.backward(gradient=y_grad) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 9dd6bae4d..3516c2205 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -158,8 +158,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), - DerivBalancer(channel_dim=-1, threshold=0.05, - max_factor=0.01), + DerivBalancer(channel_dim=-1), SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), @@ -167,8 +166,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward_macaron = nn.Sequential( ScaledLinear(d_model, dim_feedforward), - DerivBalancer(channel_dim=-1, threshold=0.05, - max_factor=0.01), + DerivBalancer(channel_dim=-1), SwishExpScale(dim_feedforward, speed=20.0), nn.Dropout(dropout), ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), From 97c0bb82d329426d80d535348651bceaab58df1c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 13 Mar 2022 13:19:20 +0800 Subject: [PATCH 072/234] Change dir name --- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 1434d6da4..897cf5411 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From f351777e9cc0da74e96212782f9056057b2407a6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 13 Mar 2022 17:29:39 +0800 Subject: [PATCH 073/234] Remove ExpScale in feedforward layes. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 13 +++++++++---- .../ASR/transducer_stateless/conformer.py | 6 +++--- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 8d01d8fc0..04481aa5b 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -565,8 +565,13 @@ class DerivBalancer(torch.nn.Module): class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" - return x * torch.sigmoid(x - 1.0) + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1), expressed for more memory-efficient + backprop as (x-1) * torch.sigmoid(x - 1) + torch.sigmoid(x - 1) + """ + x1 = x - 1.0 + s = torch.sigmoid(x1) + return (x1 * s) + s # (x-1) * s + s == x * s def _test_exp_scale_swish(): @@ -581,10 +586,10 @@ def _test_exp_scale_swish(): y1 = m1(x1) y2 = m2(x2) - assert torch.allclose(y1, y2) + assert torch.allclose(y1, y2, atol=1e-05) y1.sum().backward() y2.sum().backward() - assert torch.allclose(x1.grad, x2.grad) + assert torch.allclose(x1.grad, x2.grad, atol=1e-05) def _test_exp_scale_relu(): diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 3516c2205..e6466d8e6 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import copy import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import PeLU, ExpScale, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d +from subsampling import PeLU, ExpScale, DoubleSwish, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d import torch from torch import Tensor, nn @@ -159,7 +159,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1), - SwishExpScale(dim_feedforward, speed=20.0), + DoubleSwish(), nn.Dropout(dropout), ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) @@ -167,7 +167,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward_macaron = nn.Sequential( ScaledLinear(d_model, dim_feedforward), DerivBalancer(channel_dim=-1), - SwishExpScale(dim_feedforward, speed=20.0), + DoubleSwish(), nn.Dropout(dropout), ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 897cf5411..994b89e49 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 437e8b208341bf027744be5d81f0126635150572 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 13 Mar 2022 23:31:08 +0800 Subject: [PATCH 074/234] Reduce max-abs limit from 1000 to 100; introduce 2 DerivBalancer modules in conv layer. --- .../ASR/conformer_ctc/subsampling.py | 4 ++-- .../ASR/transducer_stateless/conformer.py | 22 ++++++++++++++++++- .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 04481aa5b..3a1eda3f1 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -327,7 +327,7 @@ class DerivBalancerFunction(torch.autograd.Function): max_positive: float, # e.g. 0.95 max_factor: float, # e.g. 0.01 min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 1000.0 + max_abs: float, # e.g. 100.0 ) -> Tensor: if x.requires_grad: if channel_dim < 0: @@ -547,7 +547,7 @@ class DerivBalancer(torch.nn.Module): max_positive: float = 0.95, max_factor: float = 0.01, min_abs: float = 0.2, - max_abs: float = 1000.0): + max_abs: float = 100.0): super(DerivBalancer, self).__init__() self.channel_dim = channel_dim self.min_positive = min_positive diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index e6466d8e6..65a8431de 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -847,6 +847,22 @@ class ConvolutionModule(nn.Module): padding=0, bias=bias, ) + + # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0) + self.depthwise_conv = ScaledConv1d( channels, channels, @@ -857,6 +873,8 @@ class ConvolutionModule(nn.Module): bias=bias, ) + + self.deriv_balancer2 = DerivBalancer(channel_dim=1) # shape: (channels, 1), broadcasts with (batch, channel, time). self.activation = SwishOffset() @@ -885,12 +903,14 @@ class ConvolutionModule(nn.Module): # GLU mechanism x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv x = self.depthwise_conv(x) - # TODO: can have a learned scale in here, or a fixed one. + x = self.deriv_balancer2(x) x = self.activation(x) x = self.pointwise_conv2(x) # (batch, channel, time) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 994b89e49..a0395a398 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From ae2568825396f63b6c3a68eb3e8d6e132d407da9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 14 Mar 2022 11:02:32 +0800 Subject: [PATCH 075/234] Make DoubleSwish more memory efficient --- .../ASR/conformer_ctc/subsampling.py | 45 +++++++++---------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 04481aa5b..6ff9be4e6 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -267,23 +267,6 @@ def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor: return (x * (scale * speed).exp()).relu() -class ExpScaleReluFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: - ctx.save_for_backward(x.detach(), scale.detach()) - ctx.speed = speed - return _exp_scale_swish(x, scale, speed) - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - x, scale = ctx.saved_tensors - x.requires_grad = True - scale.requires_grad = True - with torch.enable_grad(): - y = _exp_scale_swish(x, scale, ctx.speed) - y.backward(gradient=y_grad) - return x.grad, scale.grad, None - class ExpScaleReluFunction(torch.autograd.Function): @@ -563,16 +546,32 @@ class DerivBalancer(torch.nn.Module): self.max_abs) +def _double_swish(x: Tensor) -> Tensor: + # double-swish, implemented/approximated as offset-swish + return x * torch.sigmoid(x - 1.0) + +class DoubleSwishFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + ctx.save_for_backward(x.detach()) + return _double_swish(x) + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + # TODO: can make this more efficient. + x, = ctx.saved_tensors + x.requires_grad = True + with torch.enable_grad(): + y = _double_swish(x) + y.backward(gradient=y_grad) + return x.grad + class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1), expressed for more memory-efficient - backprop as (x-1) * torch.sigmoid(x - 1) + torch.sigmoid(x - 1) + that we approximate closely with x * sigmoid(x-1). """ - x1 = x - 1.0 - s = torch.sigmoid(x1) - return (x1 * s) + s # (x-1) * s + s == x * s - + return DoubleSwishFunction.apply(x) def _test_exp_scale_swish(): From 8d17a05dd29ef78cd6063722f3f7bb2d92f8ad0e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 14 Mar 2022 19:23:33 +0800 Subject: [PATCH 076/234] Reduce constraints from deriv-balancer in ConvModule. --- egs/librispeech/ASR/transducer_stateless/conformer.py | 8 +++----- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 65a8431de..07fe934ae 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -861,7 +861,8 @@ class ConvolutionModule(nn.Module): # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, # it will be in a better position to start learning something, i.e. to latch onto # the correct range. - self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0) + self.deriv_balancer = DerivBalancer(channel_dim=1, max_abs=10.0, + min_positive=0.0, max_positive=1.0) self.depthwise_conv = ScaledConv1d( channels, @@ -873,8 +874,6 @@ class ConvolutionModule(nn.Module): bias=bias, ) - - self.deriv_balancer2 = DerivBalancer(channel_dim=1) # shape: (channels, 1), broadcasts with (batch, channel, time). self.activation = SwishOffset() @@ -904,13 +903,12 @@ class ConvolutionModule(nn.Module): # GLU mechanism x = self.pointwise_conv1(x) # (batch, 2*channels, time) - x = self.deriv_balancer1(x) + x = self.deriv_balancer(x) x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv x = self.depthwise_conv(x) - x = self.deriv_balancer2(x) x = self.activation(x) x = self.pointwise_conv2(x) # (batch, channel, time) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index a0395a398..f2d89b099 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From a23010fc1066a791966b0244831f3bb744751587 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 14 Mar 2022 23:04:51 +0800 Subject: [PATCH 077/234] Add warmup mode --- .../ASR/transducer_stateless/conformer.py | 47 +++++++------------ .../transducer_stateless/encoder_interface.py | 4 +- .../ASR/transducer_stateless/model.py | 3 +- .../ASR/transducer_stateless/train.py | 11 +++-- 4 files changed, 29 insertions(+), 36 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 07fe934ae..b68aced9f 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -88,7 +88,7 @@ class Conformer(Transformer): def forward( - self, x: torch.Tensor, x_lens: torch.Tensor + self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -112,7 +112,8 @@ class Conformer(Transformer): assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) - x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C) + x = self.encoder(x, pos_emb, src_key_padding_mask=mask, + warmup_mode=warmup_mode) # (T, N, C) logits = self.encoder_output_layer(x) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -258,7 +259,6 @@ class ConformerEncoder(nn.Module): self.num_layers = num_layers num_channels = encoder_layer.d_model self.combiner = RandomCombine(num_inputs=len(self.aux_layers), - num_channels=num_channels, final_weight=0.5, pure_prob=0.333, stddev=2.0) @@ -269,6 +269,7 @@ class ConformerEncoder(nn.Module): pos_emb: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + warmup_mode: bool = False ) -> Tensor: r"""Pass the input through the encoder layers in turn. @@ -300,7 +301,7 @@ class ConformerEncoder(nn.Module): if i in self.aux_layers: outputs.append(output) - output = self.combiner(outputs) + output = self.combiner(outputs, warmup_mode) return output @@ -946,17 +947,12 @@ class RandomCombine(torch.nn.Module): is a random combination of all the inputs; but which in test time will be just the last input. - All but the last input will have a linear transform before we - randomly combine them; these linear transforms will be initialzed - to the identity transform. - The idea is that the list of Tensors will be a list of outputs of multiple conformer layers. This has a similar effect as iterated loss. (See: DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER NETWORKS). """ def __init__(self, num_inputs: int, - num_channels: int, final_weight: float = 0.5, pure_prob: float = 0.5, stddev: float = 2.0) -> None: @@ -965,7 +961,6 @@ class RandomCombine(torch.nn.Module): num_inputs: The number of tensor inputs, which equals the number of layers' outputs that are fed into this module. E.g. in an 18-layer neural net if we output layers 16, 12, 18, num_inputs would be 3. - num_channels: The number of channels on the input, e.g. 512. final_weight: The amount of weight or probability we assign to the final layer when randomly choosing layers or when choosing continuous layer weights. @@ -991,8 +986,6 @@ class RandomCombine(torch.nn.Module): assert pure_prob >= 0 and pure_prob <= 1 assert final_weight > 0 and final_weight < 1 assert num_inputs >= 1 - self.linear = nn.ModuleList([ScaledLinear(num_channels, num_channels, bias=True) - for _ in range(num_inputs - 1)]) self.num_inputs = num_inputs self.final_weight = final_weight @@ -1000,14 +993,10 @@ class RandomCombine(torch.nn.Module): self.stddev= stddev self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item() - self._reset_parameters() - def _reset_parameters(self): - for i in range(len(self.linear)): - nn.init.eye_(self.linear[i].weight) - nn.init.constant_(self.linear[i].bias, 0.0) - def forward(self, inputs: Sequence[Tensor]) -> Tensor: + def forward(self, inputs: Sequence[Tensor], + warmup_mode: bool) -> Tensor: """ Forward function. Args: @@ -1019,24 +1008,18 @@ class RandomCombine(torch.nn.Module): """ num_inputs = self.num_inputs assert len(inputs) == num_inputs - if not self.training: + if not (self.training and warmup_mode): return inputs[-1] # Shape of weights: (*, num_inputs) num_channels = inputs[0].shape[-1] num_frames = inputs[0].numel() // num_channels - mod_inputs = [] - for i in range(num_inputs - 1): - mod_inputs.append(self.linear[i](inputs[i])) - mod_inputs.append(inputs[num_inputs - 1]) - - ndim = inputs[0].ndim # stacked_inputs: (num_frames, num_channels, num_inputs) - stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape((num_frames, - num_channels, - num_inputs)) + stacked_inputs = torch.stack(inputs, dim=ndim).reshape((num_frames, + num_channels, + num_inputs)) # weights: (num_frames, num_inputs) weights = self._get_random_weights(inputs[0].dtype, inputs[0].device, @@ -1118,12 +1101,14 @@ def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}") num_inputs = 3 num_channels = 50 - m = RandomCombine(num_inputs=num_inputs, num_channels=num_channels, - final_weight=final_weight, pure_prob=pure_prob, stddev=stddev) + m = RandomCombine(num_inputs=num_inputs, + final_weight=final_weight, + pure_prob=pure_prob, + stddev=stddev) x = [ torch.ones(3, 4, num_channels) for _ in range(num_inputs) ] - y = m(x) + y = m(x, True) assert y.shape == x[0].shape assert torch.allclose(y, x[0]) # .. since actually all ones. diff --git a/egs/librispeech/ASR/transducer_stateless/encoder_interface.py b/egs/librispeech/ASR/transducer_stateless/encoder_interface.py index 257facce4..b295ce94b 100644 --- a/egs/librispeech/ASR/transducer_stateless/encoder_interface.py +++ b/egs/librispeech/ASR/transducer_stateless/encoder_interface.py @@ -22,7 +22,7 @@ import torch.nn as nn class EncoderInterface(nn.Module): def forward( - self, x: torch.Tensor, x_lens: torch.Tensor + self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -32,6 +32,8 @@ class EncoderInterface(nn.Module): x_lens: A tensor of shape (batch_size,) containing the number of frames in `x` before padding. + warmup_mode: for training only, if true then train in + "warmup mode" (use this for the first few thousand minibatches). Returns: Return a tuple containing two tensors: - encoder_out, a tensor of (batch_size, out_seq_len, output_dim) diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index 17b5f63e5..a45f0e295 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -62,6 +62,7 @@ class Transducer(nn.Module): x: torch.Tensor, x_lens: torch.Tensor, y: k2.RaggedTensor, + warmup_mode: bool = False ) -> torch.Tensor: """ Args: @@ -82,7 +83,7 @@ class Transducer(nn.Module): assert x.size(0) == x_lens.size(0) == y.dim0 - encoder_out, x_lens = self.encoder(x, x_lens) + encoder_out, x_lens = self.encoder(x, x_lens, warmup_mode) assert torch.all(x_lens > 0) # Now for the decoder, i.e., the prediction network diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index f2d89b099..6c318c242 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2warmup", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -203,6 +203,7 @@ def get_params() -> AttributeDict: "log_interval": 50, "reset_interval": 200, "valid_interval": 3000, # For the 100h subset, use 800 + "warmup_minibatches": 3000, # use warmup mode for 3k minibatches. # parameters for conformer "feature_dim": 80, "encoder_out_dim": 512, @@ -360,6 +361,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, + is_warmup_mode: bool = False ) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. @@ -391,7 +393,8 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - loss = model(x=feature, x_lens=feature_lens, y=y) + loss = model(x=feature, x_lens=feature_lens, y=y, + warmup_mode=is_warmup_mode) assert loss.requires_grad == is_training @@ -423,6 +426,7 @@ def compute_validation_loss( sp=sp, batch=batch, is_training=False, + is_warmup_mode=False ) assert loss.requires_grad is False tot_loss = tot_loss + loss_info @@ -484,6 +488,7 @@ def train_one_epoch( sp=sp, batch=batch, is_training=True, + is_warmup_mode=(params.batch_idx_train Date: Tue, 15 Mar 2022 13:10:35 +0800 Subject: [PATCH 078/234] Remove max-positive constraint in deriv-balancing; add second DerivBalancer in conv module. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 2 +- egs/librispeech/ASR/transducer_stateless/conformer.py | 10 ++++++---- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 1e31c0a20..7c2b1ec04 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -527,7 +527,7 @@ class DerivBalancer(torch.nn.Module): """ def __init__(self, channel_dim: int, min_positive: float = 0.05, - max_positive: float = 0.95, + max_positive: float = 1.0, max_factor: float = 0.01, min_abs: float = 0.2, max_abs: float = 100.0): diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index b68aced9f..54729652b 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -862,8 +862,7 @@ class ConvolutionModule(nn.Module): # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, # it will be in a better position to start learning something, i.e. to latch onto # the correct range. - self.deriv_balancer = DerivBalancer(channel_dim=1, max_abs=10.0, - min_positive=0.0, max_positive=1.0) + self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0) self.depthwise_conv = ScaledConv1d( channels, @@ -875,7 +874,9 @@ class ConvolutionModule(nn.Module): bias=bias, ) - # shape: (channels, 1), broadcasts with (batch, channel, time). + self.deriv_balancer2 = DerivBalancer(channel_dim=1) + + # Shape: (channels, 1), broadcasts with (batch, channel, time). self.activation = SwishOffset() self.pointwise_conv2 = ScaledConv1d( @@ -904,12 +905,13 @@ class ConvolutionModule(nn.Module): # GLU mechanism x = self.pointwise_conv1(x) # (batch, 2*channels, time) - x = self.deriv_balancer(x) + x = self.deriv_balancer1(x) x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv x = self.depthwise_conv(x) + x = self.deriv_balancer2(x) x = self.activation(x) x = self.pointwise_conv2(x) # (batch, channel, time) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 6c318c242..6408290b4 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2warmup", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 21ebd356e78b82a93485554a81402a2149874eb4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 15 Mar 2022 13:49:15 +0800 Subject: [PATCH 079/234] Add some extra info to diagnostics --- .../ASR/transducer_stateless/diagnostics.py | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py index 8ea35582a..238c50def 100644 --- a/egs/librispeech/ASR/transducer_stateless/diagnostics.py +++ b/egs/librispeech/ASR/transducer_stateless/diagnostics.py @@ -79,7 +79,7 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim options: options object sizes_same: true if all the tensor sizes are the same on this dimension - stats_type: either "abs" or "positive" or "eigs" or "value, + stats_type: either "abs" or "positive" or "eigs" or "value", imdictates the type of stats we accumulate, abs is mean absolute value, "positive" is proportion of positive to nonnegative values, "eigs" @@ -129,12 +129,23 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], percentiles.append(stats[index].item()) percentiles = [ '%.2g' % x for x in percentiles ] percentiles = ' '.join(percentiles) - return f'percentiles: [{percentiles}]' + ans = f'percentiles: [{percentiles}]' else: - stats = stats.tolist() - stats = [ '%.2g' % x for x in stats ] - stats = '[' + ' '.join(stats) + ']' - return stats + ans = stats.tolist() + ans = [ '%.2g' % x for x in ans ] + ans = '[' + ' '.join(ans) + ']' + if stats_type == "value": + norm = (stats ** 2).sum().sqrt().item() + mean_abs = stats.abs().mean().item() + # This norm is useful because it is strictly less than the largest + # sqrt(eigenvalue) of the variance, which we print out, and shows, + # speaking in an approximate way, how much of that largest eigenvalue + # can be attributed to the mean of the distribution. + ans += f', norm={norm:.2g}, mean_abs={mean_abs:.2g}' + else: + mean = stats.mean().item() + ans += f', mean={mean:.2g}' + return ans From 1962fe298b713a673cc4fd99c20e1deab45e2560 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 15 Mar 2022 14:35:15 +0800 Subject: [PATCH 080/234] Add deriv-balancer at output of embedding. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 3 +++ egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 7c2b1ec04..35de71e43 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -57,6 +57,8 @@ class Conv2dSubsampling(nn.Module): ) self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) self.out_norm = BasicNorm(odim) + # constrain mean of output to be close to zero. + self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.4, max_positive=0.6) self._reset_parameters() def _reset_parameters(self): @@ -84,6 +86,7 @@ class Conv2dSubsampling(nn.Module): x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) x = self.out_norm(x) + x = self.out_balancer(x) return x diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 6408290b4..488de3ccc 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From b2abcd721aae06deff49e7535141b9bd58bdf01a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 15 Mar 2022 16:38:19 +0800 Subject: [PATCH 081/234] Add more stats. --- .../ASR/transducer_stateless/diagnostics.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py index 238c50def..7fd83d56b 100644 --- a/egs/librispeech/ASR/transducer_stateless/diagnostics.py +++ b/egs/librispeech/ASR/transducer_stateless/diagnostics.py @@ -135,16 +135,18 @@ def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], ans = [ '%.2g' % x for x in ans ] ans = '[' + ' '.join(ans) + ']' if stats_type == "value": - norm = (stats ** 2).sum().sqrt().item() - mean_abs = stats.abs().mean().item() # This norm is useful because it is strictly less than the largest # sqrt(eigenvalue) of the variance, which we print out, and shows, # speaking in an approximate way, how much of that largest eigenvalue # can be attributed to the mean of the distribution. - ans += f', norm={norm:.2g}, mean_abs={mean_abs:.2g}' + norm = (stats ** 2).sum().sqrt().item() + mean = stats.mean().item() + rms = (stats ** 2).mean().sqrt().item() + ans += f', norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}' else: mean = stats.mean().item() - ans += f', mean={mean:.2g}' + rms = (stats ** 2).mean().sqrt().item() + ans += f', mean={mean:.2g}, rms={rms:.2g}' return ans From fc873cc50d7e5a72344b0f081e93802acb441a73 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 15 Mar 2022 17:00:17 +0800 Subject: [PATCH 082/234] Make epsilon in BasicNorm learnable, optionally. --- .../ASR/conformer_ctc/subsampling.py | 44 +++++++++++-------- .../ASR/transducer_stateless/conformer.py | 3 +- .../ASR/transducer_stateless/train.py | 2 +- 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 35de71e43..78fcac664 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -56,7 +56,10 @@ class Conv2dSubsampling(nn.Module): DoubleSwish(), ) self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) - self.out_norm = BasicNorm(odim) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(odim, learn_eps=False) # constrain mean of output to be close to zero. self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.4, max_positive=0.6) self._reset_parameters() @@ -361,42 +364,45 @@ class BasicNorm(torch.nn.Module): So the idea is to introduce this large constant value as an explicit parameter, that takes the role of the "eps" in LayerNorm, so the network - doesn't have to do this trick. - - We also introduce a learned scaling factor on the output; and we - remove the subtracting-the-mean aspect of LayerNorm (which anyway, is not - that useful unless the LayerNorm immediately follows a nonlinearity). - + doesn't have to do this trick. We make the "eps" learnable. Args: + num_channels: the number of channels, e.g. 512. channel_dim: the axis/dimension corresponding to the channel, interprted as an offset from the input's ndim if negative. shis is NOT the num_channels; it should typically be one of {-2, -1, 0, 1, 2, 3}. - initial_eps: the initial "epsilon" that we add as ballast in: + eps: the initial "epsilon" that we add as ballast in: scale = ((input_vec**2).mean() + epsilon)**-0.5 - Note: our epsilon is actually large, but we keep the name - to indicate the connection with normal LayerNorm. - - speed: a scaling factor that can be interpreted as scaling the learning - rate for this module. CAUTION: the default value of 10.0 intended to be - used with Adam or amsgrad-type optimizers, e.g. Adam or Noam. - If you are using SGD you would probably have to set `speed` to - a value less than one, or the training would be unstable. + Note: our epsilon is actually large, but we keep the name + to indicate the connection with conventional LayerNorm. + learn_eps: if true, we learn epsilon; if false, we keep it + at the initial value. + eps_speed: a constant that determines how fast "eps" learns; + with Adam and variants, this should probably be >= 1, + e.g. 5.0. For SGD and variants, probably a value less than one, + like 0.1, would be suitable, to prevent instability. """ def __init__(self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25): + eps: float = 0.25, + learn_eps: bool = True, + eps_speed: float = 5.0): super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim - self.eps = eps + self.eps_speed = eps_speed + if learn_eps: + self.eps = nn.Parameter((torch.tensor(eps).log() / eps_speed).detach()) + else: + self.register_buffer('eps', (torch.tensor(eps).log() / eps_speed).detach()) def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps) ** -0.5 + scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + + (self.eps * self.eps_speed).exp()) ** -0.5 return x * scales diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 54729652b..8b229a234 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -1129,4 +1129,5 @@ if __name__ == '__main__': seq_len = 20 # Just make sure the forward pass runs. f = c(torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64)) + torch.full((batch_size,), seq_len, dtype=torch.int64), + warmup_mode=True) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 488de3ccc..2af306f94 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed_scale", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 261d7602a77ef46626454dce7b7d70b69c79226e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 15 Mar 2022 23:46:53 +0800 Subject: [PATCH 083/234] Draft of 0mean changes.. --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 6 +++--- .../ASR/transducer_stateless/conformer.py | 13 +++++++++---- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 78fcac664..50a9db41a 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -60,8 +60,8 @@ class Conv2dSubsampling(nn.Module): # itself has learned scale, so the extra degree of freedom is not # needed. self.out_norm = BasicNorm(odim, learn_eps=False) - # constrain mean of output to be close to zero. - self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.4, max_positive=0.6) + # constrain median of output to be close to zero. + self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55) self._reset_parameters() def _reset_parameters(self): @@ -536,7 +536,7 @@ class DerivBalancer(torch.nn.Module): """ def __init__(self, channel_dim: int, min_positive: float = 0.05, - max_positive: float = 1.0, + max_positive: float = 0.95, max_factor: float = 0.01, min_abs: float = 0.2, max_abs: float = 100.0): diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 8b229a234..cc1ae53a1 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -88,7 +88,7 @@ class Conformer(Transformer): def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool + self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -179,6 +179,9 @@ class ConformerEncoderLayer(nn.Module): self.pre_norm_final = Identity() self.norm_final = BasicNorm(d_model) + # try to ensure the output is close to zero-mean (or at least, zero-median). + self.balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55) + self.dropout = nn.Dropout(dropout) @@ -227,7 +230,7 @@ class ConformerEncoderLayer(nn.Module): # feed forward module src = src + self.dropout(self.feed_forward(src)) - src = self.norm_final(self.pre_norm_final(src)) + src = self.balancer(self.norm_final(self.pre_norm_final(src))) return src @@ -862,7 +865,8 @@ class ConvolutionModule(nn.Module): # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, # it will be in a better position to start learning something, i.e. to latch onto # the correct range. - self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0) + self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0, + min_positive=0.05, max_positive=1.0) self.depthwise_conv = ScaledConv1d( channels, @@ -874,7 +878,8 @@ class ConvolutionModule(nn.Module): bias=bias, ) - self.deriv_balancer2 = DerivBalancer(channel_dim=1) + self.deriv_balancer2 = DerivBalancer(channel_dim=1, + min_positive=0.05, max_positive=1.0) # Shape: (channels, 1), broadcasts with (batch, channel, time). self.activation = SwishOffset() diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 2af306f94..41fdb4ef3 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv3warmup_embed_scale", + default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 633213424d24de73c09170d68b138ef830ed3cbd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 12:42:59 +0800 Subject: [PATCH 084/234] Rework of initialization --- .../ASR/conformer_ctc/subsampling.py | 70 ++++++++++++++++--- .../ASR/transducer_stateless/conformer.py | 16 ++--- .../ASR/transducer_stateless/decoder.py | 64 +++-------------- .../ASR/transducer_stateless/train.py | 3 +- 4 files changed, 78 insertions(+), 75 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 50a9db41a..5e44c5b29 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -62,13 +62,6 @@ class Conv2dSubsampling(nn.Module): self.out_norm = BasicNorm(odim, learn_eps=False) # constrain median of output to be close to zero. self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55) - self._reset_parameters() - - def _reset_parameters(self): - # init weights with smaller than default variance, because otherwise - # they learn too slowly in relative terms (assuming we're training with adam). - nn.init.normal_(self.conv[0].weight, std=0.05) - nn.init.constant_(self.conv[0].bias, 0.0) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -406,8 +399,36 @@ class BasicNorm(torch.nn.Module): return x * scales + + class ScaledLinear(nn.Linear): - def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs): + """ + A modified version of nn.Linear where the parameters are scaled before + use, via: + weight = self.weight * (self.weight_scale * self.scale_speed).exp() + bias = self.bias * (self.bias_scale * self.scale_speed).exp() + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + scale_speed: a factor that affects how fast the weight_scale + and bias_scale learn; this value is suitable for Adam-type + optimizers. + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + + Note: it uses the default initialization for the weight and bias, + inherited from nn.Linear. For modules with small fan-in, this + may be larger than optimal. + """ + def __init__(self, *args, + scale_speed: float = 5.0, + initial_scale: float = 1.0, + **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = (torch.tensor(initial_scale).log() / scale_speed) self.weight_scale = nn.Parameter(initial_scale.clone().detach()) @@ -417,6 +438,17 @@ class ScaledLinear(nn.Linear): else: self.register_parameter('bias_scale', None) + self._reset_parameters() # Overrides the reset_parameters in nn.Linear + + def _reset_parameters(self): + nn.init.normal_(self.weight, std=0.05) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed) + def get_weight(self): return self.weight * (self.weight_scale * self.scale_speed).exp() @@ -425,7 +457,6 @@ class ScaledLinear(nn.Linear): return (None if self.bias is None else self.bias * (self.bias_scale * self.scale_speed).exp()) - def forward(self, input: Tensor) -> Tensor: return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) @@ -442,6 +473,17 @@ class ScaledConv1d(nn.Conv1d): self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: self.register_parameter('bias_scale', None) + self._reset_parameters() # Overrides the reset_parameters in base class + + def _reset_parameters(self): + nn.init.normal_(self.weight, std=0.05) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed) + def get_weight(self): return self.weight * (self.weight_scale * self.scale_speed).exp() @@ -471,6 +513,16 @@ class ScaledConv2d(nn.Conv2d): self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: self.register_parameter('bias_scale', None) + self._reset_parameters() # Overrides the reset_parameters in base class + + def _reset_parameters(self): + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + nn.init.normal_(self.weight, std=0.05) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed) def get_weight(self): diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index cc1ae53a1..0b89fdcd2 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -162,7 +162,7 @@ class ConformerEncoderLayer(nn.Module): DerivBalancer(channel_dim=-1), DoubleSwish(), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ScaledLinear(dim_feedforward, d_model), ) self.feed_forward_macaron = nn.Sequential( @@ -170,7 +170,7 @@ class ConformerEncoderLayer(nn.Module): DerivBalancer(channel_dim=-1), DoubleSwish(), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ScaledLinear(dim_feedforward, d_model), ) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) @@ -423,7 +423,7 @@ class RelPositionMultiheadAttention(nn.Module): ), "embed_dim must be divisible by num_heads" self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25) + self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True) # linear transformation for positional encoding. self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) @@ -434,7 +434,6 @@ class RelPositionMultiheadAttention(nn.Module): self.scale_speed = scale_speed self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) - self._reset_parameters() def _pos_bias_u(self): @@ -444,12 +443,8 @@ class RelPositionMultiheadAttention(nn.Module): return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp() def _reset_parameters(self) -> None: - nn.init.xavier_uniform_(self.in_proj.weight) - nn.init.constant_(self.in_proj.bias, 0.0) - nn.init.constant_(self.out_proj.bias, 0.0) - - nn.init.xavier_uniform_(self.pos_bias_u) - nn.init.xavier_uniform_(self.pos_bias_v) + nn.init.normal_(self.pos_bias_u, std=0.05) + nn.init.normal_(self.pos_bias_v, std=0.05) def forward( self, @@ -891,7 +886,6 @@ class ConvolutionModule(nn.Module): stride=1, padding=0, bias=bias, - initial_scale=0.25 ) def forward(self, x: Tensor) -> Tensor: diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index bc4bcb3f6..838b6794d 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -183,7 +183,7 @@ class ScaledEmbedding(nn.Module): def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, scale_grad_by_freq: bool = False, - sparse: bool = False, _weight: Optional[Tensor] = None, + sparse: bool = False, scale_speed: float = 5.0) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings @@ -198,19 +198,18 @@ class ScaledEmbedding(nn.Module): self.scale_grad_by_freq = scale_grad_by_freq self.scale_speed = scale_speed - self.scale = nn.Parameter(torch.tensor(embedding_dim**0.5).log() / scale_speed) - - if _weight is None: - self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) - self.reset_parameters() - else: - assert list(_weight.shape) == [num_embeddings, embedding_dim], \ - 'Shape of weight does not match num_embeddings and embedding_dim' - self.weight = nn.Parameter(_weight) + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse + self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) + self.reset_parameters() + + + def reset_parameters(self) -> None: - nn.init.normal_(self.weight, std=self.embedding_dim**-0.5) + nn.init.normal_(self.weight, std=0.05) + nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) + if self.padding_idx is not None: with torch.no_grad(): self.weight[self.padding_idx].fill_(0) @@ -228,7 +227,6 @@ class ScaledEmbedding(nn.Module): None, 2.0, # None, 2.0 relates to normalization self.scale_grad_by_freq, self.sparse) - def extra_repr(self) -> str: s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' if self.padding_idx is not None: @@ -238,45 +236,3 @@ class ScaledEmbedding(nn.Module): if self.sparse is not False: s += ', sparse=True' return s.format(**self.__dict__) - - @classmethod - def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, - max_norm=None, norm_type=2., scale_grad_by_freq=False, - sparse=False): - r"""Creates Embedding instance from given 2-dimensional FloatTensor. - - Args: - embeddings (Tensor): FloatTensor containing weights for the Embedding. - First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``. - freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process. - Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True`` - padding_idx (int, optional): See module initialization documentation. - max_norm (float, optional): See module initialization documentation. - norm_type (float, optional): See module initialization documentation. Default ``2``. - scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``. - sparse (bool, optional): See module initialization documentation. - - Examples:: - - >>> # FloatTensor containing pretrained weights - >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]]) - >>> embedding = nn.Embedding.from_pretrained(weight) - >>> # Get embeddings for index 1 - >>> input = torch.LongTensor([1]) - >>> embedding(input) - tensor([[ 4.0000, 5.1000, 6.3000]]) - """ - assert embeddings.dim() == 2, \ - 'Embeddings parameter is expected to be 2-dimensional' - rows, cols = embeddings.shape - embedding = cls( - num_embeddings=rows, - embedding_dim=cols, - _weight=embeddings, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) - embedding.weight.requires_grad = not freeze - return embedding diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 41fdb4ef3..8f2157715 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,8 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean", + # was 2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean, then reworking initialization.. + default="transducer_stateless/randcombine1_expscale3_rework2d" help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From a783b9646729954623c37b932431ad0df0c253e3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 12:43:44 +0800 Subject: [PATCH 085/234] Fix typo --- egs/librispeech/ASR/transducer_stateless/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 8f2157715..1190522e7 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -111,7 +111,7 @@ def get_parser(): "--exp-dir", type=str, # was 2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean, then reworking initialization.. - default="transducer_stateless/randcombine1_expscale3_rework2d" + default="transducer_stateless/randcombine1_expscale3_rework2d", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved From 00be56c7a0ef956a7790598e73cf19e9ce6086cf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 12:49:00 +0800 Subject: [PATCH 086/234] Remove dead code --- .../ASR/transducer_stateless/conformer.py | 21 +------------------ 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 0b89fdcd2..cafc04ed1 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -876,8 +876,7 @@ class ConvolutionModule(nn.Module): self.deriv_balancer2 = DerivBalancer(channel_dim=1, min_positive=0.05, max_positive=1.0) - # Shape: (channels, 1), broadcasts with (batch, channel, time). - self.activation = SwishOffset() + self.activation = DoubleSwish() self.pointwise_conv2 = ScaledConv1d( channels, @@ -918,24 +917,6 @@ class ConvolutionModule(nn.Module): return x.permute(2, 0, 1) -class Swish(torch.nn.Module): - """Construct an Swish object.""" - - def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" - return x * torch.sigmoid(x) - -class SwishOffset(torch.nn.Module): - """Construct an SwishOffset object.""" - def __init__(self, offset: float = -1.0) -> None: - super(SwishOffset, self).__init__() - self.offset = offset - - def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" - return x * torch.sigmoid(x + self.offset) - - class Identity(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: return x From 0e9cad3f1f62abda43c6b218917525142c32b3d3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 14:42:53 +0800 Subject: [PATCH 087/234] Modifying initialization from normal->uniform; add initial_scale when initializing --- .../ASR/conformer_ctc/subsampling.py | 17 +++++++++++------ .../ASR/transducer_stateless/conformer.py | 7 ++++--- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 5e44c5b29..6cc90c8a1 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -441,15 +441,16 @@ class ScaledLinear(nn.Linear): self._reset_parameters() # Overrides the reset_parameters in nn.Linear def _reset_parameters(self): - nn.init.normal_(self.weight, std=0.05) + std = 0.05 + a = math.sqrt(3) * std + nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] + fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed) - def get_weight(self): return self.weight * (self.weight_scale * self.scale_speed).exp() @@ -476,7 +477,9 @@ class ScaledConv1d(nn.Conv1d): self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): - nn.init.normal_(self.weight, std=0.05) + std = 0.05 + a = math.sqrt(3) * std + nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() @@ -516,10 +519,12 @@ class ScaledConv2d(nn.Conv2d): self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - nn.init.normal_(self.weight, std=0.05) + std = 0.05 + a = math.sqrt(3) * std + nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index cafc04ed1..0832d9385 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -162,7 +162,7 @@ class ConformerEncoderLayer(nn.Module): DerivBalancer(channel_dim=-1), DoubleSwish(), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) self.feed_forward_macaron = nn.Sequential( @@ -170,7 +170,7 @@ class ConformerEncoderLayer(nn.Module): DerivBalancer(channel_dim=-1), DoubleSwish(), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) @@ -423,7 +423,7 @@ class RelPositionMultiheadAttention(nn.Module): ), "embed_dim must be divisible by num_heads" self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True) + self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25) # linear transformation for positional encoding. self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) @@ -885,6 +885,7 @@ class ConvolutionModule(nn.Module): stride=1, padding=0, bias=bias, + initial_scale=0.25 ) def forward(self, x: Tensor) -> Tensor: From 6561743d7b454111582011936beb0aa09f8fa161 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 14:55:17 +0800 Subject: [PATCH 088/234] bug fix re sqrt --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 6cc90c8a1..7c7d0ee6c 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -442,7 +442,7 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self): std = 0.05 - a = math.sqrt(3) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) @@ -478,7 +478,7 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self): std = 0.05 - a = math.sqrt(3) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) @@ -520,7 +520,7 @@ class ScaledConv2d(nn.Conv2d): def _reset_parameters(self): std = 0.05 - a = math.sqrt(3) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) From c82db4184a395177f4c2a79f1f20d7d3508777b2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 15:50:11 +0800 Subject: [PATCH 089/234] Remove xscale from pos_embedding --- egs/librispeech/ASR/conformer_ctc/subsampling.py | 6 +++--- egs/librispeech/ASR/transducer_stateless/conformer.py | 2 -- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 7c7d0ee6c..867ababf2 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -449,7 +449,7 @@ class ScaledLinear(nn.Linear): fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed) + self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) def get_weight(self): return self.weight * (self.weight_scale * self.scale_speed).exp() @@ -485,7 +485,7 @@ class ScaledConv1d(nn.Conv1d): fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed) + self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) def get_weight(self): @@ -527,7 +527,7 @@ class ScaledConv2d(nn.Conv2d): fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed) + self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) def get_weight(self): diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 0832d9385..b14e83780 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -327,7 +327,6 @@ class RelPositionalEncoding(torch.nn.Module): """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model - self.xscale = math.sqrt(self.d_model) self.dropout = torch.nn.Dropout(p=dropout_rate) self.pe = None self.extend_pe(torch.tensor(0.0).expand(1, max_len)) @@ -379,7 +378,6 @@ class RelPositionalEncoding(torch.nn.Module): """ self.extend_pe(x) - x = x * self.xscale pos_emb = self.pe[ :, self.pe.size(1) // 2 From dfc75752c40c931eb63385e793d1ababf0e02489 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 18:06:01 +0800 Subject: [PATCH 090/234] Remove some dead code. --- .../ASR/conformer_ctc/subsampling.py | 160 ------------------ .../ASR/transducer_stateless/conformer.py | 2 +- 2 files changed, 1 insertion(+), 161 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 867ababf2..500cacca8 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -174,130 +174,6 @@ class VggSubsampling(nn.Module): return x -class PeLUFunction(torch.autograd.Function): - """ - Computes PeLU function (PeLUFunction.apply(x, cutoff, alpha)). - The function is: - x.relu() + alpha * (cutoff - x).relu() - E.g. consider cutoff = -1, alpha = 0.01. This will tend to prevent die-off - of neurons. - """ - @staticmethod - def forward(ctx, x: Tensor, cutoff: float, alpha: float) -> Tensor: - mask1 = (x >= 0) # >=, so there is deriv if x == 0. - p = cutoff - x - mask2 = (p >= 0) - ctx.save_for_backward(mask1, mask2) - ctx.alpha = alpha - return x.relu() + alpha * p.relu() - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None]: - mask1, mask2 = ctx.saved_tensors - return mask1 * ans_grad - (ctx.alpha * mask2) * ans_grad, None, None - - - -class PeLU(torch.nn.Module): - def __init__(self, cutoff: float = -1.0, alpha: float = 0.01) -> None: - super(PeLU, self).__init__() - self.cutoff = cutoff - self.alpha = alpha - def forward(self, x: Tensor) -> Tensor: - return PeLUFunction.apply(x, self.cutoff, self.alpha) - -class ExpScale(torch.nn.Module): - def __init__(self, *shape, speed: float = 1.0, initial_scale: float = 1.0): - super(ExpScale, self).__init__() - scale = torch.tensor(initial_scale) - scale = scale.log() / speed - self.scale = nn.Parameter(scale.detach()) - self.speed = speed - - def forward(self, x: Tensor) -> Tensor: - return x * (self.scale * self.speed).exp() - - - -def _exp_scale_swish(x: Tensor, scale: Tensor, speed: float) -> Tensor: - # double-swish, implemented/approximated as offset-swish - x = (x * torch.sigmoid(x - 1.0)) - x = x * (scale * speed).exp() - return x - -class SwishExpScaleFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: - ctx.save_for_backward(x.detach(), scale.detach()) - ctx.speed = speed - return _exp_scale_swish(x, scale, speed) - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - x, scale = ctx.saved_tensors - x.requires_grad = True - scale.requires_grad = True - with torch.enable_grad(): - y = _exp_scale_swish(x, scale, ctx.speed) - y.backward(gradient=y_grad) - return x.grad, scale.grad, None - - -class SwishExpScale(torch.nn.Module): - # combines ExpScale and a Swish (actually the ExpScale is after the Swish). - # caution: need to specify name for speed, e.g. SwishExpScale(50, speed=4.0) - # - def __init__(self, *shape, speed: float = 1.0): - super(SwishExpScale, self).__init__() - - initial_log_scale = torch.zeros(()).detach() - self.scale = nn.Parameter(initial_log_scale) - self.speed = speed - - def forward(self, x: Tensor) -> Tensor: - return SwishExpScaleFunction.apply(x, self.scale, self.speed) - # x = (x * torch.sigmoid(x)) - # x = (x * torch.sigmoid(x)) - # x = x * (self.scale * self.speed).exp() - # return x - - - -def _exp_scale_relu(x: Tensor, scale: Tensor, speed: float) -> Tensor: - return (x * (scale * speed).exp()).relu() - - - - -class ExpScaleReluFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, scale: Tensor, speed: float) -> Tensor: - ctx.save_for_backward(x.detach(), scale.detach()) - ctx.speed = speed - return _exp_scale_relu(x, scale, speed) - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - x, scale = ctx.saved_tensors - x.requires_grad = True - scale.requires_grad = True - with torch.enable_grad(): - y = _exp_scale_relu(x, scale, ctx.speed) - y.backward(gradient=y_grad) - return x.grad, scale.grad, None - -class ExpScaleRelu(torch.nn.Module): - # combines ExpScale and Relu. - # caution: need to specify name for speed, e.g. ExpScaleRelu(50, speed=4.0) - def __init__(self, *shape, speed: float = 1.0): - super(ExpScaleRelu, self).__init__() - self.scale = nn.Parameter(torch.zeros(*shape)) - self.speed = speed - - def forward(self, x: Tensor) -> Tensor: - return ExpScaleReluFunction.apply(x, self.scale, self.speed) - # return (x * torch.sigmoid(x)) * (self.scale * self.speed).exp() - # return x * (self.scale * self.speed).exp() - @@ -639,40 +515,6 @@ class DoubleSwish(torch.nn.Module): """ return DoubleSwishFunction.apply(x) -def _test_exp_scale_swish(): - - x1 = torch.randn(50, 60).detach() - x2 = x1.detach() - - m1 = SwishExpScale(50, 1, speed=4.0) - m2 = torch.nn.Sequential(DoubleSwish(), ExpScale(50, 1, speed=4.0)) - x1.requires_grad = True - x2.requires_grad = True - - y1 = m1(x1) - y2 = m2(x2) - assert torch.allclose(y1, y2, atol=1e-05) - y1.sum().backward() - y2.sum().backward() - assert torch.allclose(x1.grad, x2.grad, atol=1e-05) - -def _test_exp_scale_relu(): - - x1 = torch.randn(50, 60).detach() - x2 = x1.detach() - - m1 = ExpScaleRelu(50, 1, speed=4.0) - m2 = torch.nn.Sequential(nn.ReLU(), ExpScale(50, 1, speed=4.0)) - x1.requires_grad = True - x2.requires_grad = True - - y1 = m1(x1) - y2 = m2(x2) - assert torch.allclose(y1, y2) - y1.sum().backward() - y2.sum().backward() - assert torch.allclose(x1.grad, x2.grad) - def _test_deriv_balancer_sign(): @@ -737,6 +579,4 @@ def _test_basic_norm(): if __name__ == '__main__': _test_deriv_balancer_sign() _test_deriv_balancer_magnitude() - _test_exp_scale_swish() - _test_exp_scale_relu() _test_basic_norm() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index b14e83780..8de02628d 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import copy import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import PeLU, ExpScale, DoubleSwish, SwishExpScale, ExpScaleRelu, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d +from subsampling import DoubleSwish, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d import torch from torch import Tensor, nn From e838c192ef05b7a4a3659672cfa54ef37f23f57b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 19:27:45 +0800 Subject: [PATCH 091/234] Cosmetic changes/renaming things --- .../ASR/conformer_ctc/subsampling.py | 59 ++++++++----------- .../ASR/transducer_stateless/conformer.py | 20 ++++--- 2 files changed, 37 insertions(+), 42 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 500cacca8..0a39b0f33 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -47,12 +47,12 @@ class Conv2dSubsampling(nn.Module): ScaledConv2d( in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), - DerivBalancer(channel_dim=1), + ActivationBalancer(channel_dim=1), DoubleSwish(), ScaledConv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), - DerivBalancer(channel_dim=1), + ActivationBalancer(channel_dim=1), DoubleSwish(), ) self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -61,7 +61,9 @@ class Conv2dSubsampling(nn.Module): # needed. self.out_norm = BasicNorm(odim, learn_eps=False) # constrain median of output to be close to zero. - self.out_balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55) + self.out_balancer = ActivationBalancer(channel_dim=-1, + min_positive=0.45, + max_positive=0.55) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -177,7 +179,7 @@ class VggSubsampling(nn.Module): -class DerivBalancerFunction(torch.autograd.Function): +class ActivationBalancerFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, channel_dim: int, @@ -428,44 +430,33 @@ class ScaledConv2d(nn.Conv2d): -class DerivBalancer(torch.nn.Module): +class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for each channel, that it is positive at least a proportion `threshold` of the time. It does this by multiplying negative derivative values by up to (1+max_factor), and positive derivative values by up to (1-max_factor), - interpolated from 0 at the threshold to those extremal values when none + interpolated from 1 at the threshold to those extremal values when none of the inputs are positive. - When all grads are zero for a channel, this - module sets all the input derivatives for that channel to -epsilon; the - idea is to bring completely dead neurons back to life this way. Args: - channel_dim: the dimension/axi corresponding to the channel, e.g. + channel_dim: the dimension/axis corresponding to the channel, e.g. -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. min_positive: the minimum, per channel, of the proportion of the time that (x > 0), below which we start to modify the derivatives. max_positive: the maximum, per channel, of the proportion of the time that (x > 0), below which we start to modify the derivatives. - max_factor: the maximum factor by which we modify the derivatives, + max_factor: the maximum factor by which we modify the derivatives for + either the sign constraint or the magnitude constraint; e.g. with max_factor=0.02, the the derivatives would be multiplied by - values in the range [0.98..1.01]. - zero: we use this value in the comparison (x > 0), i.e. we actually use - (x > zero). The reason for using a threshold slightly greater - than zero is that it will tend to prevent situations where the - inputs shrink close to zero and the nonlinearity (e.g. swish) - behaves like a linear function and we learn nothing. + values in the range [0.98..1.02]. min_abs: the minimum average-absolute-value per channel, which we allow, before we start to modify the derivatives to prevent - this. This is to prevent a failure mode where the activations - become so small that the nonlinearity effectively becomes linear, - which makes the module useless and it gets even smaller - to try to "turn it off" completely. + this. max_abs: the maximum average-absolute-value per channel, which we allow, before we start to modify the derivatives to prevent - this. This is to prevent the possibility of activations getting - out of floating point numerical range (especially in half precision). + this. """ def __init__(self, channel_dim: int, min_positive: float = 0.05, @@ -473,7 +464,7 @@ class DerivBalancer(torch.nn.Module): max_factor: float = 0.01, min_abs: float = 0.2, max_abs: float = 100.0): - super(DerivBalancer, self).__init__() + super(ActivationBalancer, self).__init__() self.channel_dim = channel_dim self.min_positive = min_positive self.max_positive = max_positive @@ -482,10 +473,10 @@ class DerivBalancer(torch.nn.Module): self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: - return DerivBalancerFunction.apply(x, self.channel_dim, - self.min_positive, self.max_positive, - self.max_factor, self.min_abs, - self.max_abs) + return ActivationBalancerFunction.apply(x, self.channel_dim, + self.min_positive, self.max_positive, + self.max_factor, self.min_abs, + self.max_abs) def _double_swish(x: Tensor) -> Tensor: @@ -524,8 +515,8 @@ def _test_deriv_balancer_sign(): x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = x.detach() x.requires_grad = True - m = DerivBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, - max_factor=0.2, min_abs=0.0) + m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, + max_factor=0.2, min_abs=0.0) y_grad = torch.sign(torch.randn(probs.numel(), N)) @@ -542,10 +533,10 @@ def _test_deriv_balancer_magnitude(): x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True - m = DerivBalancer(channel_dim=0, - min_positive=0.0, max_positive=1.0, - max_factor=0.2, - min_abs=0.2, max_abs=0.8) + m = ActivationBalancer(channel_dim=0, + min_positive=0.0, max_positive=1.0, + max_factor=0.2, + min_abs=0.2, max_abs=0.8) y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 8de02628d..6278734e5 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -19,7 +19,7 @@ import copy import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import DoubleSwish, DerivBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d +from subsampling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d import torch from torch import Tensor, nn @@ -159,7 +159,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), - DerivBalancer(channel_dim=-1), + ActivationBalancer(channel_dim=-1), DoubleSwish(), nn.Dropout(dropout), ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), @@ -167,7 +167,7 @@ class ConformerEncoderLayer(nn.Module): self.feed_forward_macaron = nn.Sequential( ScaledLinear(d_model, dim_feedforward), - DerivBalancer(channel_dim=-1), + ActivationBalancer(channel_dim=-1), DoubleSwish(), nn.Dropout(dropout), ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), @@ -180,7 +180,9 @@ class ConformerEncoderLayer(nn.Module): self.norm_final = BasicNorm(d_model) # try to ensure the output is close to zero-mean (or at least, zero-median). - self.balancer = DerivBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55) + self.balancer = ActivationBalancer(channel_dim=-1, + min_positive=0.45, + max_positive=0.55) self.dropout = nn.Dropout(dropout) @@ -858,8 +860,9 @@ class ConvolutionModule(nn.Module): # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, # it will be in a better position to start learning something, i.e. to latch onto # the correct range. - self.deriv_balancer1 = DerivBalancer(channel_dim=1, max_abs=10.0, - min_positive=0.05, max_positive=1.0) + self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0, + min_positive=0.05, + max_positive=1.0) self.depthwise_conv = ScaledConv1d( channels, @@ -871,8 +874,9 @@ class ConvolutionModule(nn.Module): bias=bias, ) - self.deriv_balancer2 = DerivBalancer(channel_dim=1, - min_positive=0.05, max_positive=1.0) + self.deriv_balancer2 = ActivationBalancer(channel_dim=1, + min_positive=0.05, + max_positive=1.0) self.activation = DoubleSwish() From 1f3a15f3c45814daefbf399d4a181b91af7cd8ea Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 22:14:30 +0800 Subject: [PATCH 092/234] Start adding some files.. --- egs/librispeech/ASR/pruned_transducer_stateless2/__init__.py | 0 .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 1 + egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py | 1 + 3 files changed, 2 insertions(+) create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/__init__.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py new file mode 120000 index 000000000..07f39b451 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -0,0 +1 @@ +../transducer/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py new file mode 120000 index 000000000..227d2247c --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/beam_search.py \ No newline at end of file From cc8e4412f7954620224a5b2f4deb80e029ce7c36 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 22:16:40 +0800 Subject: [PATCH 093/234] Add more files.. --- .../pruned_transducer_stateless2/conformer.py | 1115 +++++++++++++++++ .../pruned_transducer_stateless2/decode.py | 1 + 2 files changed, 1116 insertions(+) create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless2/decode.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py new file mode 100644 index 000000000..bf96b41f9 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -0,0 +1,1115 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +import warnings +from typing import Optional, Tuple, Sequence +from subsampling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d + +import torch +from torch import Tensor, nn +from transformer import Transformer + +from icefall.utils import make_pad_mask + + +class Conformer(Transformer): + """ + Args: + num_features (int): Number of input features + output_dim (int): Number of output dimension + subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) + d_model (int): attention dimension + nhead (int): number of head + dim_feedforward (int): feedforward dimention + num_encoder_layers (int): number of encoder layers + dropout (float): dropout rate + cnn_module_kernel (int): Kernel size of convolution module + normalize_before (bool): whether to use layer_norm before the first block. + vgg_frontend (bool): whether to use vgg frontend. + """ + + def __init__( + self, + num_features: int, + output_dim: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + vgg_frontend: bool = False, + aux_layer_period: int = 3 + ) -> None: + super(Conformer, self).__init__( + num_features=num_features, + output_dim=output_dim, + subsampling_factor=subsampling_factor, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + num_encoder_layers=num_encoder_layers, + dropout=dropout, + normalize_before=normalize_before, + vgg_frontend=vgg_frontend, + ) + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + encoder_layer = ConformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + cnn_module_kernel, + normalize_before, + ) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, + aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) + self.normalize_before = normalize_before + + + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + Returns: + Return a tuple containing 2 tensors: + - logits, its shape is (batch_size, output_seq_len, output_dim) + - logit_lens, a tensor of shape (batch_size,) containing the number + of frames in `logits` before padding. + """ + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # Caution: We assume the subsampling factor is 4! + lengths = ((x_lens - 1) // 2 - 1) // 2 + assert x.size(0) == lengths.max().item() + mask = make_pad_mask(lengths) + + x = self.encoder(x, pos_emb, src_key_padding_mask=mask, + warmup_mode=warmup_mode) # (T, N, C) + + logits = self.encoder_output_layer(x) + logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return logits, lengths + + +class ConformerEncoderLayer(nn.Module): + """ + ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + See: "Conformer: Convolution-augmented Transformer for Speech Recognition" + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + normalize_before: whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + ) -> None: + super(ConformerEncoderLayer, self).__init__() + self.d_model = d_model + + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) + + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.feed_forward_macaron = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + + + self.pre_norm_final = Identity() + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean (or at least, zero-median). + self.balancer = ActivationBalancer(channel_dim=-1, + min_positive=0.45, + max_positive=0.55) + + self.dropout = nn.Dropout(dropout) + + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + + # macaron style feed forward module + src = src + self.dropout(self.feed_forward_macaron(src)) + + + # multi-headed self-attention module + src_att = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = src + self.dropout(src_att) + + # convolution module + src = src + self.dropout(self.conv_module(src)) + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.balancer(self.norm_final(self.pre_norm_final(src))) + + return src + + +class ConformerEncoder(nn.Module): + r"""ConformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ConformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = conformer_encoder(src, pos_emb) + """ + + def __init__(self, encoder_layer: nn.Module, num_layers: int, + aux_layers: Sequence[int]) -> None: + super().__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.aux_layers = set(aux_layers + [num_layers - 1]) + assert num_layers - 1 not in aux_layers + self.num_layers = num_layers + num_channels = encoder_layer.d_model + self.combiner = RandomCombine(num_inputs=len(self.aux_layers), + final_weight=0.5, + pure_prob=0.333, + stddev=2.0) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + warmup_mode: bool = False + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + """ + output = src + + outputs = [] + + for i, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + ) + if i in self.aux_layers: + outputs.append(output) + + output = self.combiner(outputs, warmup_mode) + return output + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: Tensor) -> None: + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + + """ + self.extend_pe(x) + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x.size(1) + + 1 : self.pe.size(1) // 2 # noqa E203 + + x.size(1), + ] + return self.dropout(x), self.dropout(pos_emb) + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + scale_speed: float = 5.0 + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.scale_speed = scale_speed + self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) + self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) + self._reset_parameters() + + def _pos_bias_u(self): + return self.pos_bias_u * (self.pos_bias_u_scale * self.scale_speed).exp() + + def _pos_bias_v(self): + return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp() + + def _reset_parameters(self) -> None: + nn.init.normal_(self.pos_bias_u, std=0.05) + nn.init.normal_(self.pos_bias_v, std=0.05) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.get_weight(), + self.in_proj.get_bias(), + self.dropout, + self.out_proj.get_weight(), + self.out_proj.get_bias(), + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + ) + + def rel_shift(self, x: Tensor) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + assert n == 2 * time1 - 1 + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time1), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim) + k = k.contiguous().view(-1, bsz, num_heads, head_dim) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + src_len = k.size(0) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == src_len, "{} == {}".format( + key_padding_mask.size(1), src_len + ) + + q = q.transpose(0, 1) # (batch, time1, head, d_k) + + pos_emb_bsz = pos_emb.size(0) + assert pos_emb_bsz in (1, bsz) # actually it is 1 + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + q_with_bias_u = (q + self._pos_bias_u()).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + q_with_bias_v = (q + self._pos_bias_v()).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) + + # compute matrix b and matrix d + matrix_bd = torch.matmul( + q_with_bias_v, p.transpose(-2, -1) + ) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift(matrix_bd) + + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) + + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) + + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = ScaledConv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0, + min_positive=0.05, + max_positive=1.0) + + self.depthwise_conv = ScaledConv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + + self.deriv_balancer2 = ActivationBalancer(channel_dim=1, + min_positive=0.05, + max_positive=1.0) + + self.activation = DoubleSwish() + + self.pointwise_conv2 = ScaledConv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + initial_scale=0.25 + ) + + def forward(self, x: Tensor) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + +class Identity(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + return x + + +class RandomCombine(torch.nn.Module): + """ + This module combines a list of Tensors, all with the same shape, to + produce a single output of that same shape which, in training time, + is a random combination of all the inputs; but which in test time + will be just the last input. + + The idea is that the list of Tensors will be a list of outputs of multiple + conformer layers. This has a similar effect as iterated loss. (See: + DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER + NETWORKS). + """ + def __init__(self, num_inputs: int, + final_weight: float = 0.5, + pure_prob: float = 0.5, + stddev: float = 2.0) -> None: + """ + Args: + num_inputs: The number of tensor inputs, which equals the number of layers' + outputs that are fed into this module. E.g. in an 18-layer neural + net if we output layers 16, 12, 18, num_inputs would be 3. + final_weight: The amount of weight or probability we assign to the + final layer when randomly choosing layers or when choosing + continuous layer weights. + pure_prob: The probability, on each frame, with which we choose + only a single layer to output (rather than an interpolation) + stddev: A standard deviation that we add to log-probs for computing + randomized weights. + + The method of choosing which layers, + or combinations of layers, to use, is conceptually as follows. + With probability `pure_prob`: + With probability `final_weight`: choose final layer, + Else: choose random non-final layer. + Else: + Choose initial log-weights that correspond to assigning + weight `final_weight` to the final layer and equal + weights to other layers; then add Gaussian noise + with variance `stddev` to these log-weights, and normalize + to weights (note: the average weight assigned to the + final layer here will not be `final_weight` if stddev>0). + """ + super(RandomCombine, self).__init__() + assert pure_prob >= 0 and pure_prob <= 1 + assert final_weight > 0 and final_weight < 1 + assert num_inputs >= 1 + + self.num_inputs = num_inputs + self.final_weight = final_weight + self.pure_prob = pure_prob + self.stddev= stddev + + self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item() + + + def forward(self, inputs: Sequence[Tensor], + warmup_mode: bool) -> Tensor: + """ + Forward function. + Args: + inputs: a list of Tensor, e.g. from various layers of a transformer. + All must be the same shape, of (*, num_channels) + Returns: + a Tensor of shape (*, num_channels). In test mode + this is just the final input. + """ + num_inputs = self.num_inputs + assert len(inputs) == num_inputs + if not (self.training and warmup_mode): + return inputs[-1] + + # Shape of weights: (*, num_inputs) + num_channels = inputs[0].shape[-1] + num_frames = inputs[0].numel() // num_channels + + ndim = inputs[0].ndim + # stacked_inputs: (num_frames, num_channels, num_inputs) + stacked_inputs = torch.stack(inputs, dim=ndim).reshape((num_frames, + num_channels, + num_inputs)) + + # weights: (num_frames, num_inputs) + weights = self._get_random_weights(inputs[0].dtype, inputs[0].device, + num_frames) + + weights = weights.reshape(num_frames, num_inputs, 1) + # ans: (num_frames, num_channels, 1) + ans = torch.matmul(stacked_inputs, weights) + # ans: (*, num_channels) + ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels) + + if __name__ == "__main__": + # for testing only... + print("Weights = ", weights.reshape(num_frames, num_inputs)) + return ans + + + def _get_random_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int) -> Tensor: + """ + Return a tensor of random weights, of shape (num_frames, self.num_inputs), + Args: + dtype: the data-type desired for the answer, e.g. float, double + device: the device needed for the answer + num_frames: the number of sets of weights desired + Returns: a tensor of shape (num_frames, self.num_inputs), such that + ans.sum(dim=1) is all ones. + + """ + pure_prob = self.pure_prob + if pure_prob == 0.0: + return self._get_random_mixed_weights(dtype, device, num_frames) + elif pure_prob == 1.0: + return self._get_random_pure_weights(dtype, device, num_frames) + else: + p = self._get_random_pure_weights(dtype, device, num_frames) + m = self._get_random_mixed_weights(dtype, device, num_frames) + return torch.where(torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m) + + def _get_random_pure_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): + """ + Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs), + Args: + dtype: the data-type desired for the answer, e.g. float, double + device: the device needed for the answer + num_frames: the number of sets of weights desired + Returns: a one-hot tensor of shape (num_frames, self.num_inputs), with + exactly one weight equal to 1.0 on each frame. + """ + + final_prob = self.final_weight + + # final contains self.num_inputs - 1 in all elements + final = torch.full((num_frames,), self.num_inputs - 1, device=device) + # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) + + indexes = torch.where(torch.rand(num_frames, device=device) < final_prob, + final, nonfinal) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(dtype=dtype) + return ans + + + def _get_random_mixed_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): + """ + Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs), + Args: + dtype: the data-type desired for the answer, e.g. float, double + device: the device needed for the answer + num_frames: the number of sets of weights desired + Returns: a tensor of shape (num_frames, self.num_inputs), which elements in [0..1] that + sum to one over the second axis, i.e. ans.sum(dim=1) is all ones. + """ + logprobs = torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) * self.stddev + logprobs[:,-1] += self.final_log_weight + return logprobs.softmax(dim=1) + + +def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): + print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}") + num_inputs = 3 + num_channels = 50 + m = RandomCombine(num_inputs=num_inputs, + final_weight=final_weight, + pure_prob=pure_prob, + stddev=stddev) + + x = [ torch.ones(3, 4, num_channels) for _ in range(num_inputs) ] + + y = m(x, True) + assert y.shape == x[0].shape + assert torch.allclose(y, x[0]) # .. since actually all ones. + + +if __name__ == '__main__': + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.0) + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.3) + _test_random_combine(0.5, 1, 0.3) + _test_random_combine(0.5, 0.5, 0.3) + + feature_dim = 50 + c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c(torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + warmup_mode=True) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py new file mode 120000 index 000000000..c1125a9ba --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/decode.py \ No newline at end of file From e3ad8f63e73e8cc6a1970d451285def55e97a776 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 22:22:10 +0800 Subject: [PATCH 094/234] update decode.py file type --- .../pruned_transducer_stateless2/decode.py | 424 +++++++++++++++++- 1 file changed, 423 insertions(+), 1 deletion(-) mode change 120000 => 100755 egs/librispeech/ASR/pruned_transducer_stateless2/decode.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py deleted file mode 120000 index c1125a9ba..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless/decode.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py new file mode 100755 index 000000000..86ec6172f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -0,0 +1,423 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 100 \ + --decoding-method greedy_search + +(2) beam search +./pruned_transducer_stateless/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 100 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 100 \ + --decoding-method modified_beam_search \ + --beam-size 4 +""" + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import beam_search, greedy_search, modified_beam_search +from train import get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""Used only when --decoding-method is + beam_search or modified_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=3, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = model.device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + hyps = [] + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) + elif params.decoding_method == "modified_beam_search": + hyp = modified_beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + else: + return {f"beam_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 100 + else: + log_interval = 2 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if "beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.to(device) + model.eval() + model.device = device + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() From 11bea4513eff9b478df4ad02009fd0f491dd7ca5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 17 Mar 2022 11:17:52 +0800 Subject: [PATCH 095/234] Add remaining files in pruned_transducer_stateless2 --- .../pruned_transducer_stateless2/conformer.py | 2 +- .../pruned_transducer_stateless2/decoder.py | 241 ++++++ .../encoder_interface.py | 1 + .../pruned_transducer_stateless2/export.py | 182 ++++ .../pruned_transducer_stateless2/joiner.py | 50 ++ .../ASR/pruned_transducer_stateless2/model.py | 170 ++++ .../pruned_transducer_stateless2/scaling.py | 418 +++++++++ .../subsampling.py | 176 ++++ .../ASR/pruned_transducer_stateless2/train.py | 810 ++++++++++++++++++ .../transformer.py | 418 +++++++++ .../transducer_stateless/encoder_interface.py | 4 +- 11 files changed, 2468 insertions(+), 4 deletions(-) create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless2/export.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/model.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless2/train.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index bf96b41f9..245af05e3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -19,7 +19,7 @@ import copy import math import warnings from typing import Optional, Tuple, Sequence -from subsampling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d +from scaling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d import torch from torch import Tensor, nn diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py new file mode 100644 index 000000000..7836ca999 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -0,0 +1,241 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from typing import Optional +from scaling import ScaledConv1d, ScaledLinear + + +class Decoder(nn.Module): + """This class modifies the stateless decoder from the following paper: + + RNN-transducer with stateless prediction network + https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 + + It removes the recurrent connection from the decoder, i.e., the prediction + network. Different from the above paper, it adds an extra Conv1d + right after the embedding layer. + + TODO: Implement https://arxiv.org/pdf/2109.07513.pdf + """ + + def __init__( + self, + vocab_size: int, + embedding_dim: int, + blank_id: int, + context_size: int, + ): + """ + Args: + vocab_size: + Number of tokens of the modeling unit including blank. + embedding_dim: + Dimension of the input embedding. + blank_id: + The ID of the blank symbol. + context_size: + Number of previous words to use to predict the next word. + 1 means bigram; 2 means trigram. n means (n+1)-gram. + """ + super().__init__() + self.embedding = ScaledEmbedding( + num_embeddings=vocab_size, + embedding_dim=embedding_dim, + padding_idx=blank_id, + ) + self.blank_id = blank_id + + assert context_size >= 1, context_size + self.context_size = context_size + if context_size > 1: + self.conv = ScaledConv1d( + in_channels=embedding_dim, + out_channels=embedding_dim, + kernel_size=context_size, + padding=0, + groups=embedding_dim, + bias=False, + ) + self.output_linear = ScaledLinear(embedding_dim, vocab_size) + + + def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, U). + need_pad: + True to left pad the input. Should be True during training. + False to not pad the input. Should be False during inference. + Returns: + Return a tensor of shape (N, U, embedding_dim). + """ + y = y.to(torch.int64) + embedding_out = self.embedding(y) + if self.context_size > 1: + embedding_out = embedding_out.permute(0, 2, 1) + if need_pad is True: + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) + else: + # During inference time, there is no need to do extra padding + # as we only need one output + assert embedding_out.size(-1) == self.context_size + embedding_out = self.conv(embedding_out) + embedding_out = embedding_out.permute(0, 2, 1) + embedding_out = self.output_linear(F.relu(embedding_out)) + return embedding_out + + + +class ScaledEmbedding(nn.Module): + r"""A simple lookup table that stores embeddings of a fixed dictionary and size. + + This module is often used to store word embeddings and retrieve them using indices. + The input to the module is a list of indices, and the output is the corresponding + word embeddings. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` + (initialized to zeros) whenever it encounters the index. + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. + scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. + See Notes for more details regarding sparse gradients. + + Attributes: + weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) + initialized from :math:`\mathcal{N}(0, 1)` + + Shape: + - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract + - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` + + .. note:: + Keep in mind that only a limited number of optimizers support + sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), + :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) + + .. note:: + With :attr:`padding_idx` set, the embedding vector at + :attr:`padding_idx` is initialized to all zeros. However, note that this + vector can be modified afterwards, e.g., using a customized + initialization method, and thus changing the vector used to pad the + output. The gradient for this vector from :class:`~torch.nn.Embedding` + is always zero. + + Examples:: + + >>> # an Embedding module containing 10 tensors of size 3 + >>> embedding = nn.Embedding(10, 3) + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) + >>> embedding(input) + tensor([[[-0.0251, -1.6902, 0.7172], + [-0.6431, 0.0748, 0.6969], + [ 1.4970, 1.3448, -0.9685], + [-0.3677, -2.7265, -0.1685]], + + [[ 1.4970, 1.3448, -0.9685], + [ 0.4362, -0.4004, 0.9400], + [-0.6431, 0.0748, 0.6969], + [ 0.9124, -2.3616, 1.1151]]]) + + + >>> # example with padding_idx + >>> embedding = nn.Embedding(10, 3, padding_idx=0) + >>> input = torch.LongTensor([[0,2,0,5]]) + >>> embedding(input) + tensor([[[ 0.0000, 0.0000, 0.0000], + [ 0.1535, -2.0309, 0.9315], + [ 0.0000, 0.0000, 0.0000], + [-0.1655, 0.9897, 0.0635]]]) + """ + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', + 'scale_grad_by_freq', 'sparse'] + + num_embeddings: int + embedding_dim: int + padding_idx: int + scale_grad_by_freq: bool + weight: Tensor + sparse: bool + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + scale_speed: float = 5.0) -> None: + super(ScaledEmbedding, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.scale_grad_by_freq = scale_grad_by_freq + + self.scale_speed = scale_speed + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.sparse = sparse + + self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) + self.reset_parameters() + + + + def reset_parameters(self) -> None: + nn.init.normal_(self.weight, std=0.05) + nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) + + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + scale = (self.scale * self.scale_speed).exp() + if input.numel() < self.num_embeddings: + return F.embedding( + input, self.weight, self.padding_idx, + None, 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, self.sparse) * scale + else: + return F.embedding( + input, self.weight * scale, self.padding_idx, + None, 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, self.sparse) + + def extra_repr(self) -> str: + s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' + if self.padding_idx is not None: + s += ', padding_idx={padding_idx}' + if self.scale_grad_by_freq is not False: + s += ', scale_grad_by_freq={scale_grad_by_freq}' + if self.sparse is not False: + s += ', sparse=True' + return s.format(**self.__dict__) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py new file mode 120000 index 000000000..aa5d0217a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py @@ -0,0 +1 @@ +../transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py new file mode 100755 index 000000000..7d2a07817 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" +Usage: +./pruned_transducer_stateless/export.py \ + --exp-dir ./pruned_transducer_stateless/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file exp_dir/pretrained.pt + +To use the generated file with `pruned_transducer_stateless/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned_transducer_stateless/decode.py \ + --exp-dir ./pruned_transducer_stateless/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 100 \ + --bpe-model data/lang_bpe_500/bpe.model +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from train import get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + return parser + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + assert args.jit is False, "Support torchscript will be added later" + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.eval() + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py new file mode 100644 index 000000000..61bfe8186 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -0,0 +1,50 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from scaling import ScaledLinear + +class Joiner(nn.Module): + def __init__(self, input_dim: int, inner_dim: int, output_dim: int): + super().__init__() + + self.inner_linear = ScaledLinear(input_dim, inner_dim) + self.output_linear = ScaledLinear(inner_dim, output_dim) + + def forward( + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + ) -> torch.Tensor: + """ + Args: + encoder_out: + Output from the encoder. Its shape is (N, T, s_range, C). + decoder_out: + Output from the decoder. Its shape is (N, T, s_range, C). + Returns: + Return a tensor of shape (N, T, s_range, C). + """ + assert encoder_out.ndim == decoder_out.ndim == 4 + assert encoder_out.shape == decoder_out.shape + + logit = encoder_out + decoder_out + + logit = self.inner_linear(torch.tanh(logit)) + + output = self.output_linear(F.relu(logit)) + + return output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py new file mode 100644 index 000000000..e83d18e3e --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -0,0 +1,170 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface + +from icefall.utils import add_sos + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, C) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, C) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, C). It should contain + one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, C) and (N, U, C). Its + output shape is (N, T, U, C). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + warmup_mode: bool = False + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + Returns: + Return the transducer loss. + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens = self.encoder(x, x_lens, warmup_mode=warmup_mode) + assert torch.all(x_lens > 0) + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, C] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=decoder_out, + am=encoder_out, + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, C] + # lm_pruned : [B, T, prune_range, C] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=encoder_out, lm=decoder_out, ranges=ranges + ) + + # logits : [B, T, prune_range, C] + logits = self.joiner(am_pruned, lm_pruned) + + pruned_loss = k2.rnnt_loss_pruned( + logits=logits, + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return (simple_loss, pruned_loss) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py new file mode 100644 index 000000000..c8bc35fd1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -0,0 +1,418 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn +from torch import Tensor +from typing import Tuple + + + + + +class ActivationBalancerFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, + channel_dim: int, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 100.0 + ) -> Tensor: + if x.requires_grad: + if channel_dim < 0: + channel_dim += x.ndim + sum_dims = [d for d in range(x.ndim) if d != channel_dim] + xgt0 = x > 0 + proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) + factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) + if min_positive != 0.0 else 0.0) + factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 else 0.0) + factor = factor1 + factor2 + if isinstance(factor, float): + factor = torch.zeros_like(proportion_positive) + + mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) + below_threshold = (mean_abs < min_abs) + above_threshold = (mean_abs > max_abs) + + ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) + ctx.max_factor = max_factor + ctx.sum_dims = sum_dims + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: + factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors + dtype = x_grad.dtype + scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * + (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) + + neg_delta_grad = x_grad.abs() * (factor + scale_factor) + return x_grad - neg_delta_grad, None, None, None, None, None, None + + +class BasicNorm(torch.nn.Module): + """ + This is intended to be a simpler, and hopefully cheaper, replacement for + LayerNorm. The observation this is based on, is that Transformer-type + networks, especially with pre-norm, sometimes seem to set one of the + feature dimensions to a large constant value (e.g. 50), which "defeats" + the LayerNorm because the output magnitude is then not strongly dependent + on the other (useful) features. Presumably the weight and bias of the + LayerNorm are required to allow it to do this. + + So the idea is to introduce this large constant value as an explicit + parameter, that takes the role of the "eps" in LayerNorm, so the network + doesn't have to do this trick. We make the "eps" learnable. + + Args: + num_channels: the number of channels, e.g. 512. + channel_dim: the axis/dimension corresponding to the channel, + interprted as an offset from the input's ndim if negative. + shis is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. + eps: the initial "epsilon" that we add as ballast in: + scale = ((input_vec**2).mean() + epsilon)**-0.5 + Note: our epsilon is actually large, but we keep the name + to indicate the connection with conventional LayerNorm. + learn_eps: if true, we learn epsilon; if false, we keep it + at the initial value. + eps_speed: a constant that determines how fast "eps" learns; + with Adam and variants, this should probably be >= 1, + e.g. 5.0. For SGD and variants, probably a value less than one, + like 0.1, would be suitable, to prevent instability. + """ + def __init__(self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True, + eps_speed: float = 5.0): + super(BasicNorm, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.eps_speed = eps_speed + if learn_eps: + self.eps = nn.Parameter((torch.tensor(eps).log() / eps_speed).detach()) + else: + self.register_buffer('eps', (torch.tensor(eps).log() / eps_speed).detach()) + + + def forward(self, x: Tensor) -> Tensor: + assert x.shape[self.channel_dim] == self.num_channels + scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + + (self.eps * self.eps_speed).exp()) ** -0.5 + return x * scales + + + + +class ScaledLinear(nn.Linear): + """ + A modified version of nn.Linear where the parameters are scaled before + use, via: + weight = self.weight * (self.weight_scale * self.scale_speed).exp() + bias = self.bias * (self.bias_scale * self.scale_speed).exp() + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + scale_speed: a factor that affects how fast the weight_scale + and bias_scale learn; this value is suitable for Adam-type + optimizers. + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + + Note: it uses the default initialization for the weight and bias, + inherited from nn.Linear. For modules with small fan-in, this + may be larger than optimal. + """ + def __init__(self, *args, + scale_speed: float = 5.0, + initial_scale: float = 1.0, + **kwargs): + super(ScaledLinear, self).__init__(*args, **kwargs) + initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) + self.scale_speed = scale_speed + if self.bias is not None: + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) + else: + self.register_parameter('bias_scale', None) + + self._reset_parameters() # Overrides the reset_parameters in nn.Linear + + def _reset_parameters(self): + std = 0.05 + a = (3 ** 0.5) * std + nn.init.uniform_(self.weight, -a, a) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) + + def get_weight(self): + return self.weight * (self.weight_scale * self.scale_speed).exp() + + def get_bias(self): + return (None if self.bias is None else + self.bias * (self.bias_scale * self.scale_speed).exp()) + + def forward(self, input: Tensor) -> Tensor: + return torch.nn.functional.linear(input, self.get_weight(), + self.get_bias()) + + +class ScaledConv1d(nn.Conv1d): + def __init__(self, *args, scale_speed = 5.0, + initial_scale=1.0, **kwargs): + super(ScaledConv1d, self).__init__(*args, **kwargs) + self.scale_speed = scale_speed + initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) + if self.bias is not None: + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) + else: + self.register_parameter('bias_scale', None) + self._reset_parameters() # Overrides the reset_parameters in base class + + def _reset_parameters(self): + std = 0.05 + a = (3 ** 0.5) * std + nn.init.uniform_(self.weight, -a, a) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) + + + def get_weight(self): + return self.weight * (self.weight_scale * self.scale_speed).exp() + + def get_bias(self): + return (None if self.bias is None else + self.bias * (self.bias_scale * self.scale_speed).exp()) + + def forward(self, input: Tensor) -> Tensor: + F = torch.nn.functional + if self.padding_mode != 'zeros': + return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + self.get_weight(), self.get_bias(), self.stride, + _single(0), self.dilation, self.groups) + return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, + self.padding, self.dilation, self.groups) + + + +class ScaledConv2d(nn.Conv2d): + def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs): + super(ScaledConv2d, self).__init__(*args, **kwargs) + self.scale_speed = scale_speed + initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + self.weight_scale = nn.Parameter(initial_scale.clone().detach()) + if self.bias is not None: + self.bias_scale = nn.Parameter(initial_scale.clone().detach()) + else: + self.register_parameter('bias_scale', None) + self._reset_parameters() # Overrides the reset_parameters in base class + + def _reset_parameters(self): + std = 0.05 + a = (3 ** 0.5) * std + nn.init.uniform_(self.weight, -a, a) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() + scale = fan_in ** -0.5 # 1/sqrt(fan_in) + with torch.no_grad(): + self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) + + + def get_weight(self): + return self.weight * (self.weight_scale * self.scale_speed).exp() + + def get_bias(self): + return (None if self.bias is None else + self.bias * (self.bias_scale * self.scale_speed).exp()) + + def _conv_forward(self, input, weight): + F = torch.nn.functional + if self.padding_mode != 'zeros': + return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + weight, self.get_bias(), self.stride, + _pair(0), self.dilation, self.groups) + return F.conv2d(input, weight, self.get_bias(), self.stride, + self.padding, self.dilation, self.groups) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.get_weight()) + + + + +class ActivationBalancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 1 at the threshold to those extremal values when none + of the inputs are positive. + + + Args: + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_factor: the maximum factor by which we modify the derivatives for + either the sign constraint or the magnitude constraint; + e.g. with max_factor=0.02, the the derivatives would be multiplied by + values in the range [0.98..1.02]. + min_abs: the minimum average-absolute-value per channel, which + we allow, before we start to modify the derivatives to prevent + this. + max_abs: the maximum average-absolute-value per channel, which + we allow, before we start to modify the derivatives to prevent + this. + """ + def __init__(self, channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.01, + min_abs: float = 0.2, + max_abs: float = 100.0): + super(ActivationBalancer, self).__init__() + self.channel_dim = channel_dim + self.min_positive = min_positive + self.max_positive = max_positive + self.max_factor = max_factor + self.min_abs = min_abs + self.max_abs = max_abs + + def forward(self, x: Tensor) -> Tensor: + return ActivationBalancerFunction.apply(x, self.channel_dim, + self.min_positive, self.max_positive, + self.max_factor, self.min_abs, + self.max_abs) + + +def _double_swish(x: Tensor) -> Tensor: + # double-swish, implemented/approximated as offset-swish + return x * torch.sigmoid(x - 1.0) + +class DoubleSwishFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + ctx.save_for_backward(x.detach()) + return _double_swish(x) + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + # TODO: can make this more efficient. + x, = ctx.saved_tensors + x.requires_grad = True + with torch.enable_grad(): + y = _double_swish(x) + y.backward(gradient=y_grad) + return x.grad + +class DoubleSwish(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1). + """ + return DoubleSwishFunction.apply(x) + + + +def _test_activation_balancer_sign(): + channel_dim = 0 + probs = torch.arange(0, 1, 0.01) + N = 1000 + x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, + max_factor=0.2, min_abs=0.0) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_activation_balancer_sign: x = ", x) + print("_test_activation_balancer_sign: y grad = ", y_grad) + print("_test_activation_balancer_sign: x grad = ", x.grad) + +def _test_activation_balancer_magnitude(): + channel_dim = 0 + magnitudes = torch.arange(0, 1, 0.01) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer(channel_dim=0, + min_positive=0.0, max_positive=1.0, + max_factor=0.2, + min_abs=0.2, max_abs=0.8) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_activation_balancer_magnitude: x = ", x) + print("_test_activation_balancer_magnitude: y grad = ", y_grad) + print("_test_activation_balancer_magnitude: x grad = ", x.grad) + + +def _test_basic_norm(): + num_channels = 128 + m = BasicNorm(num_channels=num_channels, channel_dim=1) + + x = torch.randn(500, num_channels) + + y = m(x) + + assert y.shape == x.shape + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() + print("x rms = ", x_rms) + print("y rms = ", y_rms) + assert y_rms < x_rms + assert y_rms > 0.5 * x_rms + + + + + +if __name__ == '__main__': + _test_activation_balancer_sign() + _test_activation_balancer_magnitude() + _test_basic_norm() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py new file mode 100644 index 000000000..51b08e072 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py @@ -0,0 +1,176 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn +from torch import Tensor +from typing import Tuple +from scaling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__(self, idim: int, odim: int) -> None: + """ + Args: + idim: + Input dim. The input shape is (N, T, idim). + Caution: It requires: T >=7, idim >=7 + odim: + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) + """ + assert idim >= 7 + super().__init__() + self.conv = nn.Sequential( + ScaledConv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ) + self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(odim, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer(channel_dim=-1, + min_positive=0.45, + max_positive=0.55) + + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + x = self.out_norm(x) + x = self.out_balancer(x) + return x + + +class VggSubsampling(nn.Module): + """Trying to follow the setup described in the following paper: + https://arxiv.org/pdf/1910.09799.pdf + + This paper is not 100% explicit so I am guessing to some extent, + and trying to compare with other VGG implementations. + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-1)//2 - 1)//2, which approximates T' = T//4 + """ + + def __init__(self, idim: int, odim: int) -> None: + """Construct a VggSubsampling object. + + This uses 2 VGG blocks with 2 Conv2d layers each, + subsampling its input by a factor of 4 in the time dimensions. + + Args: + idim: + Input dim. The input shape is (N, T, idim). + Caution: It requires: T >=7, idim >=7 + odim: + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) + """ + super().__init__() + + cur_channels = 1 + layers = [] + block_dims = [32, 64] + + # The decision to use padding=1 for the 1st convolution, then padding=0 + # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by + # a back-compatibility concern so that the number of frames at the + # output would be equal to: + # (((T-1)//2)-1)//2. + # We can consider changing this by using padding=1 on the + # 2nd convolution, so the num-frames at the output would be T//4. + for block_dim in block_dims: + layers.append( + torch.nn.Conv2d( + in_channels=cur_channels, + out_channels=block_dim, + kernel_size=3, + padding=1, + stride=1, + ) + ) + layers.append(torch.nn.ReLU()) + layers.append( + torch.nn.Conv2d( + in_channels=block_dim, + out_channels=block_dim, + kernel_size=3, + padding=0, + stride=1, + ) + ) + layers.append( + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) + ) + cur_channels = block_dim + + self.layers = nn.Sequential(*layers) + + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + """ + x = x.unsqueeze(1) + x = self.layers(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + return x diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py new file mode 100755 index 000000000..51858448d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -0,0 +1,810 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless2/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 0 \ + --exp-dir pruned_transducer_stateless2/exp \ + --full-libri 1 \ + --max-duration 300 \ + --lr-factor 1.5 +""" + + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from model import Transducer +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from transformer import Noam + +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall import diagnostics + +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + transducer_stateless/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lr-factor", + type=float, + default=5.0, + help="The lr_factor for Noam optimizer", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" + "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - attention_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "attention_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + "vgg_frontend": False, + # parameters for decoder + "embedding_dim": 512, + # parameters for Noam + "warm_step": 30000, # For the 100h subset, use 8k + "model_warm_step": 3000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Conformer and Transformer + encoder = Conformer( + num_features=params.feature_dim, + output_dim=params.vocab_size, + subsampling_factor=params.subsampling_factor, + d_model=params.attention_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + vgg_frontend=params.vgg_frontend, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + input_dim=params.vocab_size, + inner_dim=params.embedding_dim, + output_dim=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + warmup_mode: bool = False +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = model.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + warmup_mode=warmup_mode, + ) + loss = params.simple_loss_scale * simple_loss + pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup_mode=(params.batch_idx_train < params.model_warm_step) + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + loss.backward() + optimizer.step() + optimizer.zero_grad() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 800 + params.warm_step = 8000 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + model.device = device + + optimizer = Noam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + ) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + + if params.print_diagnostics: + diagnostic = diagnostics.attach_diagnostics(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + return 1.0 <= c.duration <= 20.0 + + num_in_total = len(train_cuts) + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + num_left = len(train_cuts) + num_removed = num_in_total - num_left + removed_percent = num_removed / num_in_total * 100 + + logging.info(f"Before removing short and long utterances: {num_in_total}") + logging.info(f"After removing short and long utterances: {num_left}") + logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") + + train_dl = librispeech.train_dataloaders(train_cuts) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup_mode=True # may use slightly more memory + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py new file mode 100644 index 000000000..3fa847f4f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py @@ -0,0 +1,418 @@ +# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from subsampling import Conv2dSubsampling, VggSubsampling, ScaledLinear + +from icefall.utils import make_pad_mask + + +class Transformer(EncoderInterface): + def __init__( + self, + num_features: int, + output_dim: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + dropout: float = 0.1, + normalize_before: bool = True, + vgg_frontend: bool = False, + ) -> None: + """ + Args: + num_features: + The input dimension of the model. + output_dim: + The output dimension of the model. + subsampling_factor: + Number of output frames is num_in_frames // subsampling_factor. + Currently, subsampling_factor MUST be 4. + d_model: + Attention dimension. + nhead: + Number of heads in multi-head attention. + Must satisfy d_model // nhead == 0. + dim_feedforward: + The output dimension of the feedforward layers in encoder. + num_encoder_layers: + Number of encoder layers. + dropout: + Dropout in encoder. + normalize_before: + If True, use pre-layer norm; False to use post-layer norm. + vgg_frontend: + True to use vgg style frontend for subsampling. + """ + super().__init__() + + self.num_features = num_features + self.output_dim = output_dim + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + if vgg_frontend: + self.encoder_embed = VggSubsampling(num_features, d_model) + else: + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.encoder_pos = PositionalEncoding(d_model, dropout) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + encoder_norm = nn.LayerNorm(d_model) + else: + encoder_norm = None + + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=num_encoder_layers, + norm=encoder_norm, + ) + + # TODO(fangjun): remove dropout + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim) + ) + + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + Returns: + Return a tuple containing 2 tensors: + - logits, its shape is (batch_size, output_seq_len, output_dim) + - logit_lens, a tensor of shape (batch_size,) containing the number + of frames in `logits` before padding. + """ + x = self.encoder_embed(x) + x = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # Caution: We assume the subsampling factor is 4! + lengths = ((x_lens - 1) // 2 - 1) // 2 + assert x.size(0) == lengths.max().item() + + mask = make_pad_mask(lengths) + x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) + + logits = self.encoder_output_layer(x) + logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return logits, lengths + + +class TransformerEncoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerEncoderLayer. + Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + normalize_before: + whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + normalize_before: bool = True, + ) -> None: + super(TransformerEncoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self.normalize_before = normalize_before + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerEncoderLayer, self).__setstate__(state) + + def forward( + self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional) + + Shape: + src: (S, N, E). + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + residual = src + if self.normalize_before: + src = self.norm1(src) + src2 = self.self_attn( + src, + src, + src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout1(src2) + if not self.normalize_before: + src = self.norm1(src) + + residual = src + if self.normalize_before: + src = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = residual + self.dropout2(src2) + if not self.normalize_before: + src = self.norm2(src) + return src + + +def _get_activation_fn(activation: str): + if activation == "relu": + return nn.functional.relu + elif activation == "gelu": + return nn.functional.gelu + + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) + + +class PositionalEncoding(nn.Module): + """This class implements the positional encoding + proposed in the following paper: + + - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf + + PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) + PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) + + Note:: + + 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) + = exp(-1* 2i / d_model * log(100000)) + = exp(2i * -(log(10000) / d_model)) + """ + + def __init__(self, d_model: int, dropout: float = 0.1) -> None: + """ + Args: + d_model: + Embedding dimension. + dropout: + Dropout probability to be applied to the output of this module. + """ + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = nn.Dropout(p=dropout) + # not doing: self.pe = None because of errors thrown by torchscript + self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32) + + def extend_pe(self, x: torch.Tensor) -> None: + """Extend the time t in the positional encoding if required. + + The shape of `self.pe` is (1, T1, d_model). The shape of the input x + is (N, T, d_model). If T > T1, then we change the shape of self.pe + to (N, T, d_model). Otherwise, nothing is done. + + Args: + x: + It is a tensor of shape (N, T, C). + Returns: + Return None. + """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + # Now pe is of shape (1, T, d_model), where T is x.size(1) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Add positional encoding. + + Args: + x: + Its shape is (N, T, C) + + Returns: + Return a tensor of shape (N, T, C) + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1), :] + return self.dropout(x) + + +class Noam(object): + """ + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) diff --git a/egs/librispeech/ASR/transducer_stateless/encoder_interface.py b/egs/librispeech/ASR/transducer_stateless/encoder_interface.py index b295ce94b..3d218dcd0 100644 --- a/egs/librispeech/ASR/transducer_stateless/encoder_interface.py +++ b/egs/librispeech/ASR/transducer_stateless/encoder_interface.py @@ -22,7 +22,7 @@ import torch.nn as nn class EncoderInterface(nn.Module): def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool + self, x: torch.Tensor, x_lens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -32,8 +32,6 @@ class EncoderInterface(nn.Module): x_lens: A tensor of shape (batch_size,) containing the number of frames in `x` before padding. - warmup_mode: for training only, if true then train in - "warmup mode" (use this for the first few thousand minibatches). Returns: Return a tuple containing two tensors: - encoder_out, a tensor of (batch_size, out_seq_len, output_dim) From 13db33ffa2dba26a528748979fa202b6949fc0e6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 17 Mar 2022 15:53:53 +0800 Subject: [PATCH 096/234] Fix diagnostics-getting code --- .../ASR/pruned_transducer_stateless2/train.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 51858448d..b7cd45334 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -115,7 +115,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless/exp", + default="pruned_transducer_stateless2/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -556,6 +556,9 @@ def train_one_epoch( optimizer.step() optimizer.zero_grad() + if params.print_diagnostics and batch_idx == 5: + return + if batch_idx % params.log_interval == 0: logging.info( f"Epoch {params.cur_epoch}, " @@ -665,7 +668,11 @@ def run(rank, world_size, args): if params.print_diagnostics: - diagnostic = diagnostics.attach_diagnostics(model) + opts = diagnostics.TensorDiagnosticOptions( + 2 ** 22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + librispeech = LibriSpeechAsrDataModule(args) From acc0eda5b0b9b20b33ff1cdbb8bb467d6bc9fdbb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 17 Mar 2022 16:09:35 +0800 Subject: [PATCH 097/234] Scale down pruned loss in warmup mode --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index b7cd45334..f95d8e73c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -450,7 +450,9 @@ def compute_loss( lm_scale=params.lm_scale, warmup_mode=warmup_mode, ) - loss = params.simple_loss_scale * simple_loss + pruned_loss + loss = params.simple_loss_scale * simple_loss + if not warmup_mode: + loss = loss + pruned_loss * (0.1 if warmup_mode else 1.0) assert loss.requires_grad == is_training From cbe6b175d1d17bd6e20e2970fba46758249fa11c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 17 Mar 2022 16:46:59 +0800 Subject: [PATCH 098/234] Reduce warmup scale on pruned loss form 0.1 to 0.01. --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index f95d8e73c..f7eb15c01 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -452,7 +452,7 @@ def compute_loss( ) loss = params.simple_loss_scale * simple_loss if not warmup_mode: - loss = loss + pruned_loss * (0.1 if warmup_mode else 1.0) + loss = loss + (pruned_loss * 0.01 if warmup_mode else pruned_loss) assert loss.requires_grad == is_training From 6769087d702b3b8fed473e2da487772622be26c1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 18 Mar 2022 16:31:25 +0800 Subject: [PATCH 099/234] Remove scale_speed, make swish deriv more efficient. --- .../pruned_transducer_stateless2/conformer.py | 6 +- .../pruned_transducer_stateless2/decoder.py | 138 +---------- .../pruned_transducer_stateless2/scaling.py | 222 ++++++++++++++---- 3 files changed, 181 insertions(+), 185 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 245af05e3..cb4652840 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -410,7 +410,6 @@ class RelPositionMultiheadAttention(nn.Module): embed_dim: int, num_heads: int, dropout: float = 0.0, - scale_speed: float = 5.0 ) -> None: super(RelPositionMultiheadAttention, self).__init__() self.embed_dim = embed_dim @@ -430,16 +429,15 @@ class RelPositionMultiheadAttention(nn.Module): # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.scale_speed = scale_speed self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) self._reset_parameters() def _pos_bias_u(self): - return self.pos_bias_u * (self.pos_bias_u_scale * self.scale_speed).exp() + return self.pos_bias_u * self.pos_bias_u_scale.exp() def _pos_bias_v(self): - return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp() + return self.pos_bias_v * self.pos_bias_v_scale.exp() def _reset_parameters(self) -> None: nn.init.normal_(self.pos_bias_u, std=0.05) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index 7836ca999..47a519dc9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -19,7 +19,7 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor from typing import Optional -from scaling import ScaledConv1d, ScaledLinear +from scaling import ScaledConv1d, ScaledLinear, ScaledEmbedding class Decoder(nn.Module): @@ -103,139 +103,3 @@ class Decoder(nn.Module): embedding_out = embedding_out.permute(0, 2, 1) embedding_out = self.output_linear(F.relu(embedding_out)) return embedding_out - - - -class ScaledEmbedding(nn.Module): - r"""A simple lookup table that stores embeddings of a fixed dictionary and size. - - This module is often used to store word embeddings and retrieve them using indices. - The input to the module is a list of indices, and the output is the corresponding - word embeddings. - - Args: - num_embeddings (int): size of the dictionary of embeddings - embedding_dim (int): the size of each embedding vector - padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` - (initialized to zeros) whenever it encounters the index. - max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` - is renormalized to have norm :attr:`max_norm`. - norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. - scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of - the words in the mini-batch. Default ``False``. - sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. - See Notes for more details regarding sparse gradients. - - Attributes: - weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) - initialized from :math:`\mathcal{N}(0, 1)` - - Shape: - - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract - - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` - - .. note:: - Keep in mind that only a limited number of optimizers support - sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), - :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) - - .. note:: - With :attr:`padding_idx` set, the embedding vector at - :attr:`padding_idx` is initialized to all zeros. However, note that this - vector can be modified afterwards, e.g., using a customized - initialization method, and thus changing the vector used to pad the - output. The gradient for this vector from :class:`~torch.nn.Embedding` - is always zero. - - Examples:: - - >>> # an Embedding module containing 10 tensors of size 3 - >>> embedding = nn.Embedding(10, 3) - >>> # a batch of 2 samples of 4 indices each - >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) - >>> embedding(input) - tensor([[[-0.0251, -1.6902, 0.7172], - [-0.6431, 0.0748, 0.6969], - [ 1.4970, 1.3448, -0.9685], - [-0.3677, -2.7265, -0.1685]], - - [[ 1.4970, 1.3448, -0.9685], - [ 0.4362, -0.4004, 0.9400], - [-0.6431, 0.0748, 0.6969], - [ 0.9124, -2.3616, 1.1151]]]) - - - >>> # example with padding_idx - >>> embedding = nn.Embedding(10, 3, padding_idx=0) - >>> input = torch.LongTensor([[0,2,0,5]]) - >>> embedding(input) - tensor([[[ 0.0000, 0.0000, 0.0000], - [ 0.1535, -2.0309, 0.9315], - [ 0.0000, 0.0000, 0.0000], - [-0.1655, 0.9897, 0.0635]]]) - """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] - - num_embeddings: int - embedding_dim: int - padding_idx: int - scale_grad_by_freq: bool - weight: Tensor - sparse: bool - - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - scale_speed: float = 5.0) -> None: - super(ScaledEmbedding, self).__init__() - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - if padding_idx is not None: - if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' - elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' - padding_idx = self.num_embeddings + padding_idx - self.padding_idx = padding_idx - self.scale_grad_by_freq = scale_grad_by_freq - - self.scale_speed = scale_speed - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() - self.sparse = sparse - - self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) - self.reset_parameters() - - - - def reset_parameters(self) -> None: - nn.init.normal_(self.weight, std=0.05) - nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) - - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def forward(self, input: Tensor) -> Tensor: - scale = (self.scale * self.scale_speed).exp() - if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale - else: - return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) - - def extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' - if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' - if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' - if self.sparse is not False: - s += ', sparse=True' - return s.format(**self.__dict__) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index c8bc35fd1..f0e3fe148 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn from torch import Tensor -from typing import Tuple +from typing import Tuple, Optional @@ -94,31 +94,25 @@ class BasicNorm(torch.nn.Module): to indicate the connection with conventional LayerNorm. learn_eps: if true, we learn epsilon; if false, we keep it at the initial value. - eps_speed: a constant that determines how fast "eps" learns; - with Adam and variants, this should probably be >= 1, - e.g. 5.0. For SGD and variants, probably a value less than one, - like 0.1, would be suitable, to prevent instability. """ def __init__(self, num_channels: int, channel_dim: int = -1, # CAUTION: see documentation. eps: float = 0.25, - learn_eps: bool = True, - eps_speed: float = 5.0): + learn_eps: bool = True) -> None: super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim - self.eps_speed = eps_speed if learn_eps: - self.eps = nn.Parameter((torch.tensor(eps).log() / eps_speed).detach()) + self.eps = nn.Parameter(torch.tensor(eps).log().detach()) else: - self.register_buffer('eps', (torch.tensor(eps).log() / eps_speed).detach()) + self.register_buffer('eps', torch.tensor(eps).log().detach()) def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + - (self.eps * self.eps_speed).exp()) ** -0.5 + self.eps.exp()) ** -0.5 return x * scales @@ -128,16 +122,13 @@ class ScaledLinear(nn.Linear): """ A modified version of nn.Linear where the parameters are scaled before use, via: - weight = self.weight * (self.weight_scale * self.scale_speed).exp() - bias = self.bias * (self.bias_scale * self.scale_speed).exp() + weight = self.weight * self.weight_scale.exp() + bias = self.bias * self.bias_scale.exp() Args: Accepts the standard args and kwargs that nn.Linear accepts e.g. in_features, out_features, bias=False. - scale_speed: a factor that affects how fast the weight_scale - and bias_scale learn; this value is suitable for Adam-type - optimizers. initial_scale: you can override this if you want to increase or decrease the initial magnitude of the module's output (affects the initialization of weight_scale and bias_scale). @@ -149,13 +140,11 @@ class ScaledLinear(nn.Linear): may be larger than optimal. """ def __init__(self, *args, - scale_speed: float = 5.0, initial_scale: float = 1.0, **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) - initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - self.scale_speed = scale_speed if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: @@ -172,14 +161,14 @@ class ScaledLinear(nn.Linear): fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) + self.weight_scale += torch.tensor(scale / std).log() def get_weight(self): - return self.weight * (self.weight_scale * self.scale_speed).exp() + return self.weight * self.weight_scale.exp() def get_bias(self): return (None if self.bias is None else - self.bias * (self.bias_scale * self.scale_speed).exp()) + self.bias * self.bias_scale.exp()) def forward(self, input: Tensor) -> Tensor: return torch.nn.functional.linear(input, self.get_weight(), @@ -187,11 +176,10 @@ class ScaledLinear(nn.Linear): class ScaledConv1d(nn.Conv1d): - def __init__(self, *args, scale_speed = 5.0, + def __init__(self, *args, initial_scale=1.0, **kwargs): super(ScaledConv1d, self).__init__(*args, **kwargs) - self.scale_speed = scale_speed - initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) @@ -208,15 +196,15 @@ class ScaledConv1d(nn.Conv1d): fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) + self.weight_scale += torch.tensor(scale / std).log() def get_weight(self): - return self.weight * (self.weight_scale * self.scale_speed).exp() + return self.weight * self.weight_scale.exp() def get_bias(self): return (None if self.bias is None else - self.bias * (self.bias_scale * self.scale_speed).exp()) + self.bias * self.bias_scale.exp()) def forward(self, input: Tensor) -> Tensor: F = torch.nn.functional @@ -230,10 +218,9 @@ class ScaledConv1d(nn.Conv1d): class ScaledConv2d(nn.Conv2d): - def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs): + def __init__(self, *args, initial_scale=1.0, **kwargs): super(ScaledConv2d, self).__init__(*args, **kwargs) - self.scale_speed = scale_speed - initial_scale = (torch.tensor(initial_scale).log() / scale_speed) + initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) @@ -250,15 +237,15 @@ class ScaledConv2d(nn.Conv2d): fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) + self.weight_scale += torch.tensor(scale / std).log() def get_weight(self): - return self.weight * (self.weight_scale * self.scale_speed).exp() + return self.weight * self.weight_scale.exp() def get_bias(self): return (None if self.bias is None else - self.bias * (self.bias_scale * self.scale_speed).exp()) + self.bias * self.bias_scale.exp()) def _conv_forward(self, input, weight): F = torch.nn.functional @@ -323,6 +310,16 @@ class ActivationBalancer(torch.nn.Module): self.max_factor, self.min_abs, self.max_abs) +# deriv of double_swish: +# double_swish(x) = x * torch.sigmoid(x-1) [this is a definition, originally +# motivated by its similarity to swish(swish(x), +# where swish(x) = x *sigmoid(x)]. +# double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) +# double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). +# Now, s'(x) = s(x) * (1-s(x)). +# double_swish'(x) = x * s'(x) + s(x). +# = x * s(x) * (1-s(x)) + s(x). +# = double_swish(x) * (1-s(x)) + s(x) def _double_swish(x: Tensor) -> Tensor: # double-swish, implemented/approximated as offset-swish @@ -331,18 +328,16 @@ def _double_swish(x: Tensor) -> Tensor: class DoubleSwishFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor) -> Tensor: - ctx.save_for_backward(x.detach()) - return _double_swish(x) + x = x.detach() + s = torch.sigmoid(x - 1.0) + y = x * s + ctx.save_for_backward(s, y) + return y @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - # TODO: can make this more efficient. - x, = ctx.saved_tensors - x.requires_grad = True - with torch.enable_grad(): - y = _double_swish(x) - y.backward(gradient=y_grad) - return x.grad + s, y = ctx.saved_tensors + return (y * (1-s) + s) * y_grad class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: @@ -353,6 +348,140 @@ class DoubleSwish(torch.nn.Module): + +class ScaledEmbedding(nn.Module): + r"""A simple lookup table that stores embeddings of a fixed dictionary and size. + + This module is often used to store word embeddings and retrieve them using indices. + The input to the module is a list of indices, and the output is the corresponding + word embeddings. + + Args: + num_embeddings (int): size of the dictionary of embeddings + embedding_dim (int): the size of each embedding vector + padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` + (initialized to zeros) whenever it encounters the index. + max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` + is renormalized to have norm :attr:`max_norm`. + norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. + scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of + the words in the mini-batch. Default ``False``. + sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. + See Notes for more details regarding sparse gradients. + + Attributes: + weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) + initialized from :math:`\mathcal{N}(0, 1)` + + Shape: + - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract + - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` + + .. note:: + Keep in mind that only a limited number of optimizers support + sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), + :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) + + .. note:: + With :attr:`padding_idx` set, the embedding vector at + :attr:`padding_idx` is initialized to all zeros. However, note that this + vector can be modified afterwards, e.g., using a customized + initialization method, and thus changing the vector used to pad the + output. The gradient for this vector from :class:`~torch.nn.Embedding` + is always zero. + + Examples:: + + >>> # an Embedding module containing 10 tensors of size 3 + >>> embedding = nn.Embedding(10, 3) + >>> # a batch of 2 samples of 4 indices each + >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) + >>> embedding(input) + tensor([[[-0.0251, -1.6902, 0.7172], + [-0.6431, 0.0748, 0.6969], + [ 1.4970, 1.3448, -0.9685], + [-0.3677, -2.7265, -0.1685]], + + [[ 1.4970, 1.3448, -0.9685], + [ 0.4362, -0.4004, 0.9400], + [-0.6431, 0.0748, 0.6969], + [ 0.9124, -2.3616, 1.1151]]]) + + + >>> # example with padding_idx + >>> embedding = nn.Embedding(10, 3, padding_idx=0) + >>> input = torch.LongTensor([[0,2,0,5]]) + >>> embedding(input) + tensor([[[ 0.0000, 0.0000, 0.0000], + [ 0.1535, -2.0309, 0.9315], + [ 0.0000, 0.0000, 0.0000], + [-0.1655, 0.9897, 0.0635]]]) + """ + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', + 'scale_grad_by_freq', 'sparse'] + + num_embeddings: int + embedding_dim: int + padding_idx: int + scale_grad_by_freq: bool + weight: Tensor + sparse: bool + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False) -> None: + super(ScaledEmbedding, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.scale_grad_by_freq = scale_grad_by_freq + + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.sparse = sparse + + self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) + self.reset_parameters() + + + + def reset_parameters(self) -> None: + nn.init.normal_(self.weight, std=0.05) + nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log()) + + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + scale = self.scale.exp() + if input.numel() < self.num_embeddings: + return F.embedding( + input, self.weight, self.padding_idx, + None, 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, self.sparse) * scale + else: + return F.embedding( + input, self.weight * scale, self.padding_idx, + None, 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, self.sparse) + + def extra_repr(self) -> str: + s = '{num_embeddings}, {embedding_dim}, scale={scale}' + if self.padding_idx is not None: + s += ', padding_idx={padding_idx}' + if self.scale_grad_by_freq is not False: + s += ', scale_grad_by_freq={scale_grad_by_freq}' + if self.sparse is not False: + s += ', sparse=True' + return s.format(**self.__dict__) + + def _test_activation_balancer_sign(): channel_dim = 0 probs = torch.arange(0, 1, 0.01) @@ -409,10 +538,15 @@ def _test_basic_norm(): assert y_rms > 0.5 * x_rms - +def _test_double_swish_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 0.5 + x.requires_grad = True + m = DoubleSwish() + torch.autograd.gradcheck(m, x) if __name__ == '__main__': _test_activation_balancer_sign() _test_activation_balancer_magnitude() _test_basic_norm() + _test_double_swish_deriv() From ba3611cefd1af82ef343beec9daef9d2e795f3a0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 18 Mar 2022 16:35:48 +0800 Subject: [PATCH 100/234] Cosmetic changes to swish --- .../pruned_transducer_stateless2/scaling.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index f0e3fe148..d03bd0967 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -310,22 +310,22 @@ class ActivationBalancer(torch.nn.Module): self.max_factor, self.min_abs, self.max_abs) -# deriv of double_swish: -# double_swish(x) = x * torch.sigmoid(x-1) [this is a definition, originally -# motivated by its similarity to swish(swish(x), -# where swish(x) = x *sigmoid(x)]. -# double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) -# double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). -# Now, s'(x) = s(x) * (1-s(x)). -# double_swish'(x) = x * s'(x) + s(x). -# = x * s(x) * (1-s(x)) + s(x). -# = double_swish(x) * (1-s(x)) + s(x) - -def _double_swish(x: Tensor) -> Tensor: - # double-swish, implemented/approximated as offset-swish - return x * torch.sigmoid(x - 1.0) class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x), where swish(x) = x * sigmoid(x). + + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). + = double_swish(x) * (1-s(x)) + s(x) + ... so we just need to remember s(x) but not x itself. + """ @staticmethod def forward(ctx, x: Tensor) -> Tensor: x = x.detach() From 2dfcd8f1176851be3a8dbff5c7abde0ef0793cf0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 18 Mar 2022 16:38:36 +0800 Subject: [PATCH 101/234] Double warm_step --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index f7eb15c01..ae45db60f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -615,7 +615,7 @@ def run(rank, world_size, args): params.update(vars(args)) if params.full_libri is False: params.valid_interval = 800 - params.warm_step = 8000 + params.warm_step = 16000 fix_random_seed(params.seed) if world_size > 1: From c9f1aeb7d18eaa33c5d8b7f1fe7365ac9a0ff971 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 18 Mar 2022 16:40:24 +0800 Subject: [PATCH 102/234] Fix bug with import --- egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index d03bd0967..2d0331312 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -459,6 +459,7 @@ class ScaledEmbedding(nn.Module): self.weight[self.padding_idx].fill_(0) def forward(self, input: Tensor) -> Tensor: + F = torch.nn.functional scale = self.scale.exp() if input.numel() < self.num_embeddings: return F.embedding( From 188eada7ac8f761b130f4a3bdbbeb92e8160e38d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 18 Mar 2022 21:28:34 +0800 Subject: [PATCH 103/234] Change initial std from 0.05 to 0.025. --- egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 2d0331312..d4aef5cdd 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -153,7 +153,7 @@ class ScaledLinear(nn.Linear): self._reset_parameters() # Overrides the reset_parameters in nn.Linear def _reset_parameters(self): - std = 0.05 + std = 0.025 a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -188,7 +188,7 @@ class ScaledConv1d(nn.Conv1d): self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): - std = 0.05 + std = 0.025 a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -229,7 +229,7 @@ class ScaledConv2d(nn.Conv2d): self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): - std = 0.05 + std = 0.025 a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: From 8cff994cd7da9880ca63de95212fb1bd7d0a2bc0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 18 Mar 2022 21:30:05 +0800 Subject: [PATCH 104/234] Set also scale for embedding to 0.025. --- egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index d4aef5cdd..b358e5fa2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -451,8 +451,9 @@ class ScaledEmbedding(nn.Module): def reset_parameters(self) -> None: - nn.init.normal_(self.weight, std=0.05) - nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log()) + std = 0.025 + nn.init.normal_(self.weight, std=std) + nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) if self.padding_idx is not None: with torch.no_grad(): From 0ee2404ff09057205812dc0d6b39495192a87c80 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 19 Mar 2022 14:01:45 +0800 Subject: [PATCH 105/234] Remove logging code that broke with newer Lhotse; fix bug with pruned_loss --- .../ASR/pruned_transducer_stateless2/train.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index ae45db60f..851822aae 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -450,9 +450,8 @@ def compute_loss( lm_scale=params.lm_scale, warmup_mode=warmup_mode, ) - loss = params.simple_loss_scale * simple_loss - if not warmup_mode: - loss = loss + (pruned_loss * 0.01 if warmup_mode else pruned_loss) + loss = (params.simple_loss_scale * simple_loss + + (pruned_loss * 0.01 if warmup_mode else pruned_loss)) assert loss.requires_grad == is_training @@ -687,18 +686,8 @@ def run(rank, world_size, args): # Keep only utterances with duration between 1 second and 20 seconds return 1.0 <= c.duration <= 20.0 - num_in_total = len(train_cuts) - train_cuts = train_cuts.filter(remove_short_and_long_utt) - num_left = len(train_cuts) - num_removed = num_in_total - num_left - removed_percent = num_removed / num_in_total * 100 - - logging.info(f"Before removing short and long utterances: {num_in_total}") - logging.info(f"After removing short and long utterances: {num_left}") - logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") - train_dl = librispeech.train_dataloaders(train_cuts) valid_cuts = librispeech.dev_clean_cuts() From 05b5e78d8f2298cf6b4b757a620df099dfc0841d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 21 Mar 2022 15:55:11 +0800 Subject: [PATCH 106/234] Add norm+balancer to VggSubsampling --- .../ASR/pruned_transducer_stateless2/subsampling.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py index 51b08e072..c2da23adc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py @@ -158,6 +158,12 @@ class VggSubsampling(nn.Module): self.out = nn.Linear( block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim ) + self.out_norm = BasicNorm(odim, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer(channel_dim=-1, + min_positive=0.45, + max_positive=0.55) + def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -173,4 +179,6 @@ class VggSubsampling(nn.Module): x = self.layers(x) b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x = self.out_norm(x) + x = self.out_balancer(x) return x From ccbf8ba0862347007fb6aed87fff6f152d1bc35f Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 21 Mar 2022 16:51:48 +0800 Subject: [PATCH 107/234] Incorporate changes from master into pruned_transducer_stateless2. --- .../pruned_transducer_stateless2/decode.py | 173 ++++++++++++++---- .../pruned_transducer_stateless2/decoder.py | 1 + .../ASR/pruned_transducer_stateless2/train.py | 121 ++++++++++-- icefall/diagnostics.py | 9 +- 4 files changed, 253 insertions(+), 51 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 86ec6172f..ad76411c0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -42,6 +42,17 @@ Usage: --max-duration 100 \ --decoding-method modified_beam_search \ --beam-size 4 + +(4) fast beam search +./pruned_transducer_stateless/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 1500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 """ @@ -49,16 +60,26 @@ import argparse import logging from collections import defaultdict from pathlib import Path -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple +import k2 import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import beam_search, greedy_search, modified_beam_search +from beam_search import ( + beam_search, + fast_beam_search, + greedy_search, + modified_beam_search, +) from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -88,6 +109,17 @@ def get_parser(): "'--epoch'. ", ) + parser.add_argument( + "--avg-last-n", + type=int, + default=0, + help="""If positive, --epoch and --avg are ignored and it + will use the last n checkpoints exp_dir/checkpoint-xxx.pt + where xxx is the number of processed batches while + saving that checkpoint. + """, + ) + parser.add_argument( "--exp-dir", type=str, @@ -110,6 +142,7 @@ def get_parser(): - greedy_search - beam_search - modified_beam_search + - fast_beam_search """, ) @@ -117,8 +150,35 @@ def get_parser(): "--beam-size", type=int, default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, help="""Used only when --decoding-method is - beam_search or modified_beam_search""", + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", ) parser.add_argument( @@ -144,6 +204,7 @@ def decode_one_batch( model: nn.Module, sp: spm.SentencePieceProcessor, batch: dict, + decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -166,6 +227,9 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -184,36 +248,62 @@ def decode_one_batch( x=feature, x_lens=feature_lens ) hyps = [] - batch_size = encoder_out.size(0) - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - elif params.decoding_method == "modified_beam_search": - hyp = modified_beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + elif params.decoding_method == "modified_beam_search": + hyp = modified_beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) if params.decoding_method == "greedy_search": return {"greedy_search": hyps} + elif params.decoding_method == "fast_beam_search": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps + } else: - return {f"beam_{params.beam_size}": hyps} + return {f"beam_size_{params.beam_size}": hyps} def decode_dataset( @@ -221,6 +311,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -233,6 +324,9 @@ def decode_dataset( The neural model. sp: The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -260,6 +354,7 @@ def decode_dataset( params=params, model=model, sp=sp, + decoding_graph=decoding_graph, batch=batch, ) @@ -340,12 +435,17 @@ def main(): assert params.decoding_method in ( "greedy_search", "beam_search", + "fast_beam_search", "modified_beam_search", ) params.res_dir = params.exp_dir / params.decoding_method params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if "beam_search" in params.decoding_method: + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" @@ -372,7 +472,12 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - if params.avg == 1: + if params.avg_last_n > 0: + filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: start = params.epoch - params.avg + 1 @@ -388,6 +493,11 @@ def main(): model.eval() model.device = device + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -408,6 +518,7 @@ def main(): params=params, model=model, sp=sp, + decoding_graph=decoding_graph, ) save_results( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index 47a519dc9..13e45e03b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -64,6 +64,7 @@ class Decoder(nn.Module): assert context_size >= 1, context_size self.context_size = context_size + self.vocab_size = vocab_size if context_size > 1: self.conv = ScaledConv1d( in_channels=embedding_dim, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 851822aae..d28a8a060 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -36,7 +36,7 @@ import argparse import logging from pathlib import Path from shutil import copyfile -from typing import Optional, Tuple +from typing import Any, Dict, Optional, Tuple import k2 import sentencepiece as spm @@ -48,6 +48,7 @@ from conformer import Conformer from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer from torch import Tensor @@ -55,8 +56,9 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall import diagnostics @@ -112,6 +114,15 @@ def get_parser(): """, ) + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + parser.add_argument( "--exp-dir", type=str, @@ -192,6 +203,30 @@ def get_parser(): help="Accumulate stats on activations, print them and exit.", ) + parser.add_argument( + "--save-every-n", + type=int, + default=8000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + return parser @@ -320,15 +355,16 @@ def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: +) -> Optional[Dict[str, Any]]: """Load checkpoint from file. - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, and `best_valid_loss` in `params`. Args: @@ -338,20 +374,22 @@ def load_checkpoint_if_available( The training model. optimizer: The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. Returns: - Return None. + Return a dict containing previously saved training info. """ - if params.start_epoch <= 0: - return + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 0: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" saved_params = load_checkpoint( filename, model=model, optimizer=optimizer, - scheduler=scheduler, ) keys = [ @@ -360,10 +398,13 @@ def load_checkpoint_if_available( "batch_idx_train", "best_train_loss", "best_valid_loss", + "cur_batch_idx", ] for k in keys: params[k] = saved_params[k] + params["start_epoch"] = saved_params["cur_epoch"] + return saved_params @@ -371,7 +412,7 @@ def save_checkpoint( params: AttributeDict, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + sampler: Optional[CutSampler] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -381,6 +422,10 @@ def save_checkpoint( It is returned by :func:`get_params`. model: The training model. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. """ if rank != 0: return @@ -390,7 +435,7 @@ def save_checkpoint( model=model, params=params, optimizer=optimizer, - scheduler=scheduler, + sampler=sampler, rank=rank, ) @@ -509,6 +554,7 @@ def train_one_epoch( valid_dl: torch.utils.data.DataLoader, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, + rank: int = 0, ) -> None: """Train the model for one epoch. @@ -531,12 +577,21 @@ def train_one_epoch( Writer to write log messages to tensorboard. world_size: Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. """ model.train() tot_loss = MetricsTracker() + cur_batch_idx = params.get("cur_batch_idx", 0) + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -560,6 +615,27 @@ def train_one_epoch( if params.print_diagnostics and batch_idx == 5: return + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + params=params, + optimizer=optimizer, + sampler=train_dl.sampler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + if batch_idx % params.log_interval == 0: logging.info( f"Epoch {params.cur_epoch}, " @@ -688,7 +764,14 @@ def run(rank, world_size, args): train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_dl = librispeech.train_dataloaders(train_cuts) + if checkpoints and "sampler" in checkpoints: + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) valid_cuts = librispeech.dev_clean_cuts() valid_cuts += librispeech.dev_other_cuts() @@ -728,6 +811,7 @@ def run(rank, world_size, args): valid_dl=valid_dl, tb_writer=tb_writer, world_size=world_size, + rank=rank, ) if params.print_diagnostics: @@ -738,6 +822,7 @@ def run(rank, world_size, args): params=params, model=model, optimizer=optimizer, + sampler=train_dl.sampler, rank=rank, ) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index fa9b98fa0..06eacd736 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -135,8 +135,13 @@ def get_diagnostics_for_dim( return "" count = sum(counts) stats = stats / count - stats, _ = torch.symeig(stats) - stats = stats.abs().sqrt() + try: + eigs, _ = torch.symeig(stats) + stats = eigs.abs().sqrt() + except: + print("Error getting eigenvalues, trying another method") + eigs, _ = torch.eigs(stats) + stats = eigs.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance elif sizes_same: stats = torch.stack(stats).sum(dim=0) From 05e30d0c461f2428a12a8a13d980f14320bf13be Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 21 Mar 2022 21:15:00 +0800 Subject: [PATCH 108/234] Add max-abs=6, debugged version --- .../ASR/pruned_transducer_stateless2/conformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index cb4652840..c6470b4a2 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -176,13 +176,13 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.pre_norm_final = Identity() self.norm_final = BasicNorm(d_model) # try to ensure the output is close to zero-mean (or at least, zero-median). self.balancer = ActivationBalancer(channel_dim=-1, min_positive=0.45, - max_positive=0.55) + max_positive=0.55, + max_abs=6.0) self.dropout = nn.Dropout(dropout) @@ -232,7 +232,7 @@ class ConformerEncoderLayer(nn.Module): # feed forward module src = src + self.dropout(self.feed_forward(src)) - src = self.balancer(self.norm_final(self.pre_norm_final(src))) + src = self.norm_final(self.balancer(src)) return src From 11a04c50ae15505c7c480963203531abe0c65e98 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 21 Mar 2022 21:29:24 +0800 Subject: [PATCH 109/234] Change 0.025,0.05 to 0.01 in initializations --- .../ASR/pruned_transducer_stateless2/conformer.py | 4 ++-- .../ASR/pruned_transducer_stateless2/scaling.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index c6470b4a2..f778c9226 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -440,8 +440,8 @@ class RelPositionMultiheadAttention(nn.Module): return self.pos_bias_v * self.pos_bias_v_scale.exp() def _reset_parameters(self) -> None: - nn.init.normal_(self.pos_bias_u, std=0.05) - nn.init.normal_(self.pos_bias_v, std=0.05) + nn.init.normal_(self.pos_bias_u, std=0.01) + nn.init.normal_(self.pos_bias_v, std=0.01) def forward( self, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index b358e5fa2..f2423492f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -153,7 +153,7 @@ class ScaledLinear(nn.Linear): self._reset_parameters() # Overrides the reset_parameters in nn.Linear def _reset_parameters(self): - std = 0.025 + std = 0.01 a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -188,7 +188,7 @@ class ScaledConv1d(nn.Conv1d): self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): - std = 0.025 + std = 0.01 a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -229,7 +229,7 @@ class ScaledConv2d(nn.Conv2d): self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): - std = 0.025 + std = 0.01 a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -451,7 +451,7 @@ class ScaledEmbedding(nn.Module): def reset_parameters(self) -> None: - std = 0.025 + std = 0.01 nn.init.normal_(self.weight, std=std) nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) From 2eef001d39dcd68f230a8072cac9350b78b9f950 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 21 Mar 2022 23:59:26 +0800 Subject: [PATCH 110/234] Fix balancer code --- egs/librispeech/ASR/transducer_stateless/conformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index bf96b41f9..909f9a74c 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -176,13 +176,13 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.pre_norm_final = Identity() self.norm_final = BasicNorm(d_model) # try to ensure the output is close to zero-mean (or at least, zero-median). self.balancer = ActivationBalancer(channel_dim=-1, min_positive=0.45, - max_positive=0.55) + max_positive=0.55, + max_positive=6.0) self.dropout = nn.Dropout(dropout) @@ -232,7 +232,7 @@ class ConformerEncoderLayer(nn.Module): # feed forward module src = src + self.dropout(self.feed_forward(src)) - src = self.balancer(self.norm_final(self.pre_norm_final(src))) + src = self.norm_final(self.balancer(src)) return src From b7e84d5d77cb313579d54b58cc4be3f660af9038 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 21 Mar 2022 23:59:53 +0800 Subject: [PATCH 111/234] Whitespace fix --- egs/librispeech/ASR/transducer_stateless/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 909f9a74c..f7b96a6a1 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -230,7 +230,7 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(self.conv_module(src)) # feed forward module - src = src + self.dropout(self.feed_forward(src)) + src = src + self.dropout(self.feed_forward(src)) src = self.norm_final(self.balancer(src)) From b82a505dfc003ea9b919dc49d60d780837e927bd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 22 Mar 2022 12:30:48 +0800 Subject: [PATCH 112/234] Reduce initial pruned_loss scale from 0.01 to 0.0 --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index d28a8a060..b9409127e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -496,7 +496,7 @@ def compute_loss( warmup_mode=warmup_mode, ) loss = (params.simple_loss_scale * simple_loss + - (pruned_loss * 0.01 if warmup_mode else pruned_loss)) + (pruned_loss * 0.0 if warmup_mode else pruned_loss)) assert loss.requires_grad == is_training From 4004ca81d84cda612265bd6919bea168e39601da Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 22 Mar 2022 13:32:24 +0800 Subject: [PATCH 113/234] Increase warm_step (and valid_interval) --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index b9409127e..096f93d77 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -295,7 +295,7 @@ def get_params() -> AttributeDict: # parameters for decoder "embedding_dim": 512, # parameters for Noam - "warm_step": 30000, # For the 100h subset, use 8k + "warm_step": 60000, # For the 100h subset, use 8k "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), } @@ -689,8 +689,8 @@ def run(rank, world_size, args): params = get_params() params.update(vars(args)) if params.full_libri is False: - params.valid_interval = 800 - params.warm_step = 16000 + params.valid_interval = 1600 + params.warm_step = 30000 fix_random_seed(params.seed) if world_size > 1: From cef634870300c6d2a00f6b538ccc5d64975d0766 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 22 Mar 2022 13:50:54 +0800 Subject: [PATCH 114/234] Change max-abs from 6 to 10 --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index f778c9226..d90dd34e1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -182,7 +182,7 @@ class ConformerEncoderLayer(nn.Module): self.balancer = ActivationBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55, - max_abs=6.0) + max_abs=10.0) self.dropout = nn.Dropout(dropout) From 9a8aa1f54ab4154571974eea3c795f0b7ad49758 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 22 Mar 2022 15:36:20 +0800 Subject: [PATCH 115/234] Change how warmup works. --- .../pruned_transducer_stateless2/conformer.py | 221 +++--------------- .../ASR/pruned_transducer_stateless2/model.py | 7 +- .../ASR/pruned_transducer_stateless2/train.py | 13 +- 3 files changed, 38 insertions(+), 203 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index d90dd34e1..83bcc3f3e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -88,7 +88,7 @@ class Conformer(Transformer): def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool = False + self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -97,6 +97,10 @@ class Conformer(Transformer): x_lens: A tensor of shape (batch_size,) containing the number of frames in `x` before padding. + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. Returns: Return a tuple containing 2 tensors: - logits, its shape is (batch_size, output_seq_len, output_dim) @@ -113,7 +117,7 @@ class Conformer(Transformer): mask = make_pad_mask(lengths) x = self.encoder(x, pos_emb, src_key_padding_mask=mask, - warmup_mode=warmup_mode) # (T, N, C) + warmup=warmup) # (T, N, C) logits = self.encoder_output_layer(x) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -193,6 +197,8 @@ class ConformerEncoderLayer(nn.Module): pos_emb: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + warmup: float = 1.0, + position: float = 0.0 ) -> Tensor: """ Pass the input through the encoder layer. @@ -202,6 +208,11 @@ class ConformerEncoderLayer(nn.Module): pos_emb: Positional embedding tensor (required). src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). + warmup: controls selective activation of layers; if < 1.0, it's possible that + not all modules will be included. + position: the position of this module in the encoder stack (relates to + warmup); a value 0 <= position < 1.0. + Shape: src: (S, N, E). @@ -210,9 +221,9 @@ class ConformerEncoderLayer(nn.Module): src_key_padding_mask: (N, S). S is the source sequence length, N is the batch size, E is the feature number """ - # macaron style feed forward module - src = src + self.dropout(self.feed_forward_macaron(src)) + src = torch.add(src, self.dropout(self.feed_forward_macaron(src)), + alpha=(0.0 if warmup < 0.2 * (position + 1) else 1.0)) # multi-headed self-attention module @@ -224,13 +235,16 @@ class ConformerEncoderLayer(nn.Module): attn_mask=src_mask, key_padding_mask=src_key_padding_mask, )[0] - src = src + self.dropout(src_att) + src = torch.add(src, self.dropout(src_att), + alpha=(0.0 if warmup < 0.2 * (position + 2) else 1.0)) # convolution module - src = src + self.dropout(self.conv_module(src)) + src = torch.add(src, self.dropout(self.conv_module(src)), + alpha=(0.0 if warmup < 0.2 * (position + 3) else 1.0)) # feed forward module - src = src + self.dropout(self.feed_forward(src)) + src = torch.add(src, self.dropout(self.feed_forward(src)), + alpha=(0.0 if warmup < 0.2 * (position + 4) else 1.0)) src = self.norm_final(self.balancer(src)) @@ -262,10 +276,6 @@ class ConformerEncoder(nn.Module): assert num_layers - 1 not in aux_layers self.num_layers = num_layers num_channels = encoder_layer.d_model - self.combiner = RandomCombine(num_inputs=len(self.aux_layers), - final_weight=0.5, - pure_prob=0.333, - stddev=2.0) def forward( self, @@ -273,7 +283,7 @@ class ConformerEncoder(nn.Module): pos_emb: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - warmup_mode: bool = False + warmup: float = 1.0 ) -> Tensor: r"""Pass the input through the encoder layers in turn. @@ -293,7 +303,7 @@ class ConformerEncoder(nn.Module): """ output = src - outputs = [] + num_layers = len(self.layers) for i, mod in enumerate(self.layers): output = mod( @@ -301,11 +311,10 @@ class ConformerEncoder(nn.Module): pos_emb, src_mask=mask, src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + position=(i / num_layers), ) - if i in self.aux_layers: - outputs.append(output) - output = self.combiner(outputs, warmup_mode) return output @@ -922,187 +931,9 @@ class Identity(torch.nn.Module): return x -class RandomCombine(torch.nn.Module): - """ - This module combines a list of Tensors, all with the same shape, to - produce a single output of that same shape which, in training time, - is a random combination of all the inputs; but which in test time - will be just the last input. - - The idea is that the list of Tensors will be a list of outputs of multiple - conformer layers. This has a similar effect as iterated loss. (See: - DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER - NETWORKS). - """ - def __init__(self, num_inputs: int, - final_weight: float = 0.5, - pure_prob: float = 0.5, - stddev: float = 2.0) -> None: - """ - Args: - num_inputs: The number of tensor inputs, which equals the number of layers' - outputs that are fed into this module. E.g. in an 18-layer neural - net if we output layers 16, 12, 18, num_inputs would be 3. - final_weight: The amount of weight or probability we assign to the - final layer when randomly choosing layers or when choosing - continuous layer weights. - pure_prob: The probability, on each frame, with which we choose - only a single layer to output (rather than an interpolation) - stddev: A standard deviation that we add to log-probs for computing - randomized weights. - - The method of choosing which layers, - or combinations of layers, to use, is conceptually as follows. - With probability `pure_prob`: - With probability `final_weight`: choose final layer, - Else: choose random non-final layer. - Else: - Choose initial log-weights that correspond to assigning - weight `final_weight` to the final layer and equal - weights to other layers; then add Gaussian noise - with variance `stddev` to these log-weights, and normalize - to weights (note: the average weight assigned to the - final layer here will not be `final_weight` if stddev>0). - """ - super(RandomCombine, self).__init__() - assert pure_prob >= 0 and pure_prob <= 1 - assert final_weight > 0 and final_weight < 1 - assert num_inputs >= 1 - - self.num_inputs = num_inputs - self.final_weight = final_weight - self.pure_prob = pure_prob - self.stddev= stddev - - self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item() - - - def forward(self, inputs: Sequence[Tensor], - warmup_mode: bool) -> Tensor: - """ - Forward function. - Args: - inputs: a list of Tensor, e.g. from various layers of a transformer. - All must be the same shape, of (*, num_channels) - Returns: - a Tensor of shape (*, num_channels). In test mode - this is just the final input. - """ - num_inputs = self.num_inputs - assert len(inputs) == num_inputs - if not (self.training and warmup_mode): - return inputs[-1] - - # Shape of weights: (*, num_inputs) - num_channels = inputs[0].shape[-1] - num_frames = inputs[0].numel() // num_channels - - ndim = inputs[0].ndim - # stacked_inputs: (num_frames, num_channels, num_inputs) - stacked_inputs = torch.stack(inputs, dim=ndim).reshape((num_frames, - num_channels, - num_inputs)) - - # weights: (num_frames, num_inputs) - weights = self._get_random_weights(inputs[0].dtype, inputs[0].device, - num_frames) - - weights = weights.reshape(num_frames, num_inputs, 1) - # ans: (num_frames, num_channels, 1) - ans = torch.matmul(stacked_inputs, weights) - # ans: (*, num_channels) - ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels) - - if __name__ == "__main__": - # for testing only... - print("Weights = ", weights.reshape(num_frames, num_inputs)) - return ans - - - def _get_random_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int) -> Tensor: - """ - Return a tensor of random weights, of shape (num_frames, self.num_inputs), - Args: - dtype: the data-type desired for the answer, e.g. float, double - device: the device needed for the answer - num_frames: the number of sets of weights desired - Returns: a tensor of shape (num_frames, self.num_inputs), such that - ans.sum(dim=1) is all ones. - - """ - pure_prob = self.pure_prob - if pure_prob == 0.0: - return self._get_random_mixed_weights(dtype, device, num_frames) - elif pure_prob == 1.0: - return self._get_random_pure_weights(dtype, device, num_frames) - else: - p = self._get_random_pure_weights(dtype, device, num_frames) - m = self._get_random_mixed_weights(dtype, device, num_frames) - return torch.where(torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m) - - def _get_random_pure_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): - """ - Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs), - Args: - dtype: the data-type desired for the answer, e.g. float, double - device: the device needed for the answer - num_frames: the number of sets of weights desired - Returns: a one-hot tensor of shape (num_frames, self.num_inputs), with - exactly one weight equal to 1.0 on each frame. - """ - - final_prob = self.final_weight - - # final contains self.num_inputs - 1 in all elements - final = torch.full((num_frames,), self.num_inputs - 1, device=device) - # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. - nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) - - indexes = torch.where(torch.rand(num_frames, device=device) < final_prob, - final, nonfinal) - ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(dtype=dtype) - return ans - - - def _get_random_mixed_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): - """ - Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs), - Args: - dtype: the data-type desired for the answer, e.g. float, double - device: the device needed for the answer - num_frames: the number of sets of weights desired - Returns: a tensor of shape (num_frames, self.num_inputs), which elements in [0..1] that - sum to one over the second axis, i.e. ans.sum(dim=1) is all ones. - """ - logprobs = torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) * self.stddev - logprobs[:,-1] += self.final_log_weight - return logprobs.softmax(dim=1) - - -def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): - print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}") - num_inputs = 3 - num_channels = 50 - m = RandomCombine(num_inputs=num_inputs, - final_weight=final_weight, - pure_prob=pure_prob, - stddev=stddev) - - x = [ torch.ones(3, 4, num_channels) for _ in range(num_inputs) ] - - y = m(x, True) - assert y.shape == x[0].shape - assert torch.allclose(y, x[0]) # .. since actually all ones. if __name__ == '__main__': - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.0) - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.3) - _test_random_combine(0.5, 1, 0.3) - _test_random_combine(0.5, 0.5, 0.3) - feature_dim = 50 c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4) batch_size = 5 @@ -1110,4 +941,4 @@ if __name__ == '__main__': # Just make sure the forward pass runs. f = c(torch.randn(batch_size, seq_len, feature_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), - warmup_mode=True) + warmup=0.5) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index e83d18e3e..faaebc477 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -66,7 +66,7 @@ class Transducer(nn.Module): prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, - warmup_mode: bool = False + warmup: float = 1.0, ) -> torch.Tensor: """ Args: @@ -87,6 +87,9 @@ class Transducer(nn.Module): lm_scale: The scale to smooth the loss with lm (output of predictor network) part + warmup: + A value warmup >= 0 that determines which modules are active, values + warmup > 1 "are fully warmed up" and all modules will be active. Returns: Return the transducer loss. @@ -102,7 +105,7 @@ class Transducer(nn.Module): assert x.size(0) == x_lens.size(0) == y.dim0 - encoder_out, x_lens = self.encoder(x, x_lens, warmup_mode=warmup_mode) + encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup) assert torch.all(x_lens > 0) # Now for the decoder, i.e., the prediction network diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 096f93d77..d4a2e83d5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -296,7 +296,7 @@ def get_params() -> AttributeDict: "embedding_dim": 512, # parameters for Noam "warm_step": 60000, # For the 100h subset, use 8k - "model_warm_step": 3000, # arg given to model, not for lrate + "model_warm_step": 4000, # arg given to model, not for lrate "env_info": get_env_info(), } ) @@ -454,7 +454,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, - warmup_mode: bool = False + warmup: float = 1.0 ) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. @@ -471,6 +471,8 @@ def compute_loss( True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. """ device = model.device feature = batch["inputs"] @@ -493,10 +495,10 @@ def compute_loss( prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, - warmup_mode=warmup_mode, + warmup=warmup, ) loss = (params.simple_loss_scale * simple_loss + - (pruned_loss * 0.0 if warmup_mode else pruned_loss)) + (pruned_loss * 0.0 if warmup < 1.0 else pruned_loss)) assert loss.requires_grad == is_training @@ -601,7 +603,7 @@ def train_one_epoch( sp=sp, batch=batch, is_training=True, - warmup_mode=(params.batch_idx_train < params.model_warm_step) + warmup=(params.batch_idx_train / params.model_warm_step) ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info @@ -855,7 +857,6 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, - warmup_mode=True # may use slightly more memory ) loss.backward() optimizer.step() From aab72bc2a546872ac08a4396b382810b90af1cba Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 24 Mar 2022 13:10:54 +0800 Subject: [PATCH 116/234] Add changes from master to decode.py, train.py --- .../pruned_transducer_stateless2/decode.py | 27 ++++++++++++++----- .../ASR/pruned_transducer_stateless2/train.py | 19 ++++++++++--- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index ad76411c0..8e924bf96 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -71,6 +71,7 @@ from beam_search import ( beam_search, fast_beam_search, greedy_search, + greedy_search_batch, modified_beam_search, ) from train import get_params, get_transducer_model @@ -191,7 +192,7 @@ def get_parser(): parser.add_argument( "--max-sym-per-frame", type=int, - default=3, + default=1, help="""Maximum number of symbols per frame. Used only when --decoding_method is greedy_search""", ) @@ -261,6 +262,24 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) @@ -280,12 +299,6 @@ def decode_one_batch( encoder_out=encoder_out_i, beam=params.beam_size, ) - elif params.decoding_method == "modified_beam_search": - hyp = modified_beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index d4a2e83d5..01cf289f5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -398,12 +398,16 @@ def load_checkpoint_if_available( "batch_idx_train", "best_train_loss", "best_valid_loss", - "cur_batch_idx", ] for k in keys: params[k] = saved_params[k] - params["start_epoch"] = saved_params["cur_epoch"] + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] return saved_params @@ -762,11 +766,20 @@ def run(rank, world_size, args): def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold return 1.0 <= c.duration <= 20.0 train_cuts = train_cuts.filter(remove_short_and_long_utt) - if checkpoints and "sampler" in checkpoints: + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch sampler_state_dict = checkpoints["sampler"] else: sampler_state_dict = None From 1f548548d2875ebc6ec7f7d526d0500c5e83b18e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 24 Mar 2022 15:06:06 +0800 Subject: [PATCH 117/234] Simplify the warmup code; max_abs 10->6 --- .../pruned_transducer_stateless2/conformer.py | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 83bcc3f3e..a81777353 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -186,7 +186,7 @@ class ConformerEncoderLayer(nn.Module): self.balancer = ActivationBalancer(channel_dim=-1, min_positive=0.45, max_positive=0.55, - max_abs=10.0) + max_abs=6.0) self.dropout = nn.Dropout(dropout) @@ -198,7 +198,6 @@ class ConformerEncoderLayer(nn.Module): src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, warmup: float = 1.0, - position: float = 0.0 ) -> Tensor: """ Pass the input through the encoder layer. @@ -208,11 +207,10 @@ class ConformerEncoderLayer(nn.Module): pos_emb: Positional embedding tensor (required). src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). - warmup: controls selective activation of layers; if < 1.0, it's possible that - not all modules will be included. - position: the position of this module in the encoder stack (relates to - warmup); a value 0 <= position < 1.0. - + warmup: controls selective activation of layers; if < 0.5, it's possible that + not all modules will be included. Actually we add the + feed_forward_macaron and self_attn modules at warmup=0.0 + and the conv_module and feed_forward at warmup=0.5. Shape: src: (S, N, E). @@ -223,7 +221,7 @@ class ConformerEncoderLayer(nn.Module): """ # macaron style feed forward module src = torch.add(src, self.dropout(self.feed_forward_macaron(src)), - alpha=(0.0 if warmup < 0.2 * (position + 1) else 1.0)) + alpha=(0.0 if warmup < 0.0 else 1.0)) # multi-headed self-attention module @@ -236,15 +234,15 @@ class ConformerEncoderLayer(nn.Module): key_padding_mask=src_key_padding_mask, )[0] src = torch.add(src, self.dropout(src_att), - alpha=(0.0 if warmup < 0.2 * (position + 2) else 1.0)) + alpha=(0.0 if warmup < 0.0 else 1.0)) # convolution module src = torch.add(src, self.dropout(self.conv_module(src)), - alpha=(0.0 if warmup < 0.2 * (position + 3) else 1.0)) + alpha=(0.0 if warmup < 0.5 else 1.0)) # feed forward module src = torch.add(src, self.dropout(self.feed_forward(src)), - alpha=(0.0 if warmup < 0.2 * (position + 4) else 1.0)) + alpha=(0.0 if warmup < 0.5 else 1.0)) src = self.norm_final(self.balancer(src)) @@ -311,8 +309,7 @@ class ConformerEncoder(nn.Module): pos_emb, src_mask=mask, src_key_padding_mask=src_key_padding_mask, - warmup=warmup, - position=(i / num_layers), + warmup=warmup-0.5*(i / num_layers) ) return output From 4b650e9f015a8cef28f5a2a0574b8b3d250fcea8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 25 Mar 2022 20:34:33 +0800 Subject: [PATCH 118/234] Make warmup work by scaling layer contributions; leave residual layer-drop --- .../pruned_transducer_stateless2/conformer.py | 32 +++++++++++++------ .../ASR/pruned_transducer_stateless2/train.py | 11 +++++-- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index a81777353..64030ef90 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -219,9 +219,23 @@ class ConformerEncoderLayer(nn.Module): src_key_padding_mask: (N, S). S is the source sequence length, N is the batch size, E is the feature number """ + src_orig = src + # when warmup == 0.0, alpha is always 0.1, but it gradually changes to + # always being 1.0 when warmup equals 1.0. The reason for using 0.1 and not + # 0.0 is that it gives us a gradient so we can learn something when we are not + # being very useful. The occasional 1.0 will ensure, via self.balancer, that + # the outputs of our modules don't get scaled up too much. + + # min(0.1, warmup) + # is used in place of warmup to ensure that even at the start of the warm-up + # period we sometimes use scale 1.0; this ensures that the modules do not + # compensate for the small scale by just producing larger output. + warmup = max(warmup, 0.1) + warmup = min(warmup, 0.95) # effectively, layer-drop. + alpha = 1.0 if torch.rand(()).item() <= warmup else 0.1 + # macaron style feed forward module - src = torch.add(src, self.dropout(self.feed_forward_macaron(src)), - alpha=(0.0 if warmup < 0.0 else 1.0)) + src = torch.add(src, self.dropout(self.feed_forward_macaron(src))) # multi-headed self-attention module @@ -233,19 +247,19 @@ class ConformerEncoderLayer(nn.Module): attn_mask=src_mask, key_padding_mask=src_key_padding_mask, )[0] - src = torch.add(src, self.dropout(src_att), - alpha=(0.0 if warmup < 0.0 else 1.0)) + src = torch.add(src, self.dropout(src_att)) # convolution module - src = torch.add(src, self.dropout(self.conv_module(src)), - alpha=(0.0 if warmup < 0.5 else 1.0)) + src = torch.add(src, self.dropout(self.conv_module(src))) # feed forward module - src = torch.add(src, self.dropout(self.feed_forward(src)), - alpha=(0.0 if warmup < 0.5 else 1.0)) + src = torch.add(src, self.dropout(self.feed_forward(src))) src = self.norm_final(self.balancer(src)) + if alpha != 1.0: + src = alpha * src + (1-alpha) * src_orig + return src @@ -309,7 +323,7 @@ class ConformerEncoder(nn.Module): pos_emb, src_mask=mask, src_key_padding_mask=src_key_padding_mask, - warmup=warmup-0.5*(i / num_layers) + warmup=warmup, ) return output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 01cf289f5..35991f5e9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -296,7 +296,7 @@ def get_params() -> AttributeDict: "embedding_dim": 512, # parameters for Noam "warm_step": 60000, # For the 100h subset, use 8k - "model_warm_step": 4000, # arg given to model, not for lrate + "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), } ) @@ -501,8 +501,15 @@ def compute_loss( lm_scale=params.lm_scale, warmup=warmup, ) + # after the main warmup step, we keep pruned_loss_scale small + # for the same amount of time (model_warm_step), to avoid + # overwhelming the simple_loss and causing it to diverge, + # in case it had not fully learned the alignment yet. + pruned_loss_scale = (0.0 if warmup < 1.0 else + (0.1 if warmup > 1.0 and warmup < 2.0) else + 1.0) loss = (params.simple_loss_scale * simple_loss + - (pruned_loss * 0.0 if warmup < 1.0 else pruned_loss)) + pruned_loss_scale * pruned_loss) assert loss.requires_grad == is_training From d2ed3dfc90fa05c63433c5cc7e627bb03de209cc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 25 Mar 2022 20:35:11 +0800 Subject: [PATCH 119/234] Fix bug --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 35991f5e9..13ba99017 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -506,8 +506,8 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = (0.0 if warmup < 1.0 else - (0.1 if warmup > 1.0 and warmup < 2.0) else - 1.0) + (0.1 if warmup > 1.0 and warmup < 2.0 else + 1.0)) loss = (params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss) From 0e694739f2a33344eac9cf8b0398aa876f469853 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 25 Mar 2022 23:28:52 +0800 Subject: [PATCH 120/234] Fix test mode with random layer dropout --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 64030ef90..fae91aa71 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -231,7 +231,8 @@ class ConformerEncoderLayer(nn.Module): # period we sometimes use scale 1.0; this ensures that the modules do not # compensate for the small scale by just producing larger output. warmup = max(warmup, 0.1) - warmup = min(warmup, 0.95) # effectively, layer-drop. + if self.training: + warmup = min(warmup, 0.95) # effectively, layer-drop. alpha = 1.0 if torch.rand(()).item() <= warmup else 0.1 # macaron style feed forward module From 26a1730392163e64499672c2847ef2e10bf3bc5e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 26 Mar 2022 14:46:27 +0800 Subject: [PATCH 121/234] Add random-number-setting function in dataloader --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index a460c8eb8..a0356f68a 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -22,6 +22,8 @@ import logging from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional +import torch +import lhotse from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( @@ -301,12 +303,19 @@ class LibriSpeechAsrDataModule: logging.info("Loading sampler state dict") train_sampler.load_state_dict(sampler_state_dict) + # 'seed' is derived from the current random state, which will have previously been + # set in the main process. + seed = torch.randint(0, 100000, ()).item() + def worker_init_fn(worker_id: int): + lhotse.utils.fix_random_seed(seed + worker_id) + train_dl = DataLoader( train, sampler=train_sampler, batch_size=None, num_workers=self.args.num_workers, persistent_workers=False, + worker_init_fn=worker_init_fn, ) return train_dl From 8a38d9a855b57be5e976727084d4980aa0fd5b2a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 26 Mar 2022 15:43:47 +0800 Subject: [PATCH 122/234] Fix/patch how fix_random_seed() is imported. --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index a0356f68a..3efe7ec7a 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -23,7 +23,7 @@ from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional import torch -import lhotse +from lhotse.utils import fix_random_seed from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( @@ -307,7 +307,7 @@ class LibriSpeechAsrDataModule: # set in the main process. seed = torch.randint(0, 100000, ()).item() def worker_init_fn(worker_id: int): - lhotse.utils.fix_random_seed(seed + worker_id) + fix_random_seed(seed + worker_id) train_dl = DataLoader( train, From b43468bb67502b87296387b6a65048a85558ab04 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 26 Mar 2022 19:36:33 +0800 Subject: [PATCH 123/234] Reduce layer-drop prob --- .../pruned_transducer_stateless2/conformer.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index fae91aa71..69a7af6a9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -222,21 +222,20 @@ class ConformerEncoderLayer(nn.Module): src_orig = src # when warmup == 0.0, alpha is always 0.1, but it gradually changes to # always being 1.0 when warmup equals 1.0. The reason for using 0.1 and not - # 0.0 is that it gives us a gradient so we can learn something when we are not - # being very useful. The occasional 1.0 will ensure, via self.balancer, that - # the outputs of our modules don't get scaled up too much. - + # 0.0 is that it gives us a gradient so we can learn something when we are turned + # off. + # # min(0.1, warmup) # is used in place of warmup to ensure that even at the start of the warm-up # period we sometimes use scale 1.0; this ensures that the modules do not # compensate for the small scale by just producing larger output. warmup = max(warmup, 0.1) if self.training: - warmup = min(warmup, 0.95) # effectively, layer-drop. + warmup = min(warmup, 0.98) # effectively, layer-drop with 1-in-50 prob. alpha = 1.0 if torch.rand(()).item() <= warmup else 0.1 # macaron style feed forward module - src = torch.add(src, self.dropout(self.feed_forward_macaron(src))) + src = src + self.dropout(self.feed_forward_macaron(src)) # multi-headed self-attention module @@ -248,13 +247,13 @@ class ConformerEncoderLayer(nn.Module): attn_mask=src_mask, key_padding_mask=src_key_padding_mask, )[0] - src = torch.add(src, self.dropout(src_att)) + src = src + self.dropout(src_att) # convolution module - src = torch.add(src, self.dropout(self.conv_module(src))) + src = src + self.dropout(self.conv_module(src)) # feed forward module - src = torch.add(src, self.dropout(self.feed_forward(src))) + src = src + self.dropout(self.feed_forward(src)) src = self.norm_final(self.balancer(src)) From 953aecf5e38811edc11123d26292cd1d397e11aa Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 27 Mar 2022 00:25:32 +0800 Subject: [PATCH 124/234] Reduce layer-drop prob after warmup to 1 in 100 --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 69a7af6a9..85a3b4575 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -231,7 +231,7 @@ class ConformerEncoderLayer(nn.Module): # compensate for the small scale by just producing larger output. warmup = max(warmup, 0.1) if self.training: - warmup = min(warmup, 0.98) # effectively, layer-drop with 1-in-50 prob. + warmup = min(warmup, 0.99) # effectively, layer-drop with 1-in-100 prob. alpha = 1.0 if torch.rand(()).item() <= warmup else 0.1 # macaron style feed forward module From 8a8134b9e54e0a1b1cbda59cc1a38fe7cccb16b5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 17 Mar 2022 13:18:58 +0800 Subject: [PATCH 125/234] Change power of lr-schedule from -0.5 to -0.333 --- .../ASR/pruned_transducer_stateless2/transformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py index 3fa847f4f..aa091877c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py @@ -391,7 +391,8 @@ class Noam(object): return ( self.factor * self.model_size ** (-0.5) - * min(step ** (-0.5), step * self.warmup ** (-1.5)) + * self.warmup ** (-0.5 - -0.333) + * min(step ** (-0.333), step * self.warmup ** (-1.333)) ) def zero_grad(self): From 262388134d3b31dc6ec42fa15b99804d23e23d44 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 27 Mar 2022 11:18:16 +0800 Subject: [PATCH 126/234] Increase model_warm_step to 4k --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 13ba99017..c1e836903 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -296,7 +296,7 @@ def get_params() -> AttributeDict: "embedding_dim": 512, # parameters for Noam "warm_step": 60000, # For the 100h subset, use 8k - "model_warm_step": 3000, # arg given to model, not for lrate + "model_warm_step": 4000, # arg given to model, not for lrate "env_info": get_env_info(), } ) From 2cde99509fda0dc6fec55eab7504cacc44b6c0fc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 27 Mar 2022 23:21:42 +0800 Subject: [PATCH 127/234] Change max-keep-prob to 0.95 --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 85a3b4575..9c8302926 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -231,7 +231,7 @@ class ConformerEncoderLayer(nn.Module): # compensate for the small scale by just producing larger output. warmup = max(warmup, 0.1) if self.training: - warmup = min(warmup, 0.99) # effectively, layer-drop with 1-in-100 prob. + warmup = min(warmup, 0.95) # effectively, layer-drop with 1-in-20 prob. alpha = 1.0 if torch.rand(()).item() <= warmup else 0.1 # macaron style feed forward module From 11124b03eaed547e057f302bf02d0d75b91ae58b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 29 Mar 2022 20:32:14 +0800 Subject: [PATCH 128/234] Refactoring and simplifying conformer and frontend --- .../pruned_transducer_stateless2/conformer.py | 115 +++++++++++++--- .../subsampling.py | 127 +++--------------- .../ASR/pruned_transducer_stateless2/train.py | 2 - .../transformer.py | 5 +- 4 files changed, 115 insertions(+), 134 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 9c8302926..6b625513e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -16,6 +16,7 @@ # limitations under the License. import copy +from encoder_interface import EncoderInterface import math import warnings from typing import Optional, Tuple, Sequence @@ -23,12 +24,11 @@ from scaling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, Sc import torch from torch import Tensor, nn -from transformer import Transformer from icefall.utils import make_pad_mask -class Conformer(Transformer): +class Conformer(EncoderInterface): """ Args: num_features (int): Number of input features @@ -40,7 +40,6 @@ class Conformer(Transformer): num_encoder_layers (int): number of encoder layers dropout (float): dropout rate cnn_module_kernel (int): Kernel size of convolution module - normalize_before (bool): whether to use layer_norm before the first block. vgg_frontend (bool): whether to use vgg frontend. """ @@ -55,22 +54,22 @@ class Conformer(Transformer): num_encoder_layers: int = 12, dropout: float = 0.1, cnn_module_kernel: int = 31, - normalize_before: bool = True, - vgg_frontend: bool = False, aux_layer_period: int = 3 ) -> None: - super(Conformer, self).__init__( - num_features=num_features, - output_dim=output_dim, - subsampling_factor=subsampling_factor, - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - num_encoder_layers=num_encoder_layers, - dropout=dropout, - normalize_before=normalize_before, - vgg_frontend=vgg_frontend, - ) + super(Conformer, self).__init__() + + self.num_features = num_features + self.output_dim = output_dim + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + self.encoder_embed = Conv2dSubsampling(num_features, 128, d_model) self.encoder_pos = RelPositionalEncoding(d_model, dropout) @@ -80,11 +79,13 @@ class Conformer(Transformer): dim_feedforward, dropout, cnn_module_kernel, - normalize_before, ) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) - self.normalize_before = normalize_before + + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim) + ) def forward( @@ -136,7 +137,6 @@ class ConformerEncoderLayer(nn.Module): dim_feedforward: the dimension of the feedforward network model (default=2048). dropout: the dropout value (default=0.1). cnn_module_kernel (int): Kernel size of convolution module. - normalize_before: whether to use layer_norm before the first block. Examples:: >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) @@ -152,7 +152,6 @@ class ConformerEncoderLayer(nn.Module): dim_feedforward: int = 2048, dropout: float = 0.1, cnn_module_kernel: int = 31, - normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() self.d_model = d_model @@ -942,6 +941,80 @@ class Identity(torch.nn.Module): return x +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__(self, in_channels: int, + out_channels: int, + layer1_channels: int = 64, + layer2_channels: int = 128) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + """ + assert in_channels >= 7 + super().__init__() + self.conv = nn.Sequential( + ScaledConv2d( + in_channels=1, out_channels=layer1_channels, + kernel_size=3, stride=2 + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer1_channels, out_channels=layer2_channels, + kernel_size=3, stride=2 + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ) + self.out = ScaledLinear(layer2_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(out_channels, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer(channel_dim=-1, + min_positive=0.45, + max_positive=0.55) + + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + x = self.out_norm(x) + x = self.out_balancer(x) + return x if __name__ == '__main__': diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py index c2da23adc..12ca09a17 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py @@ -32,34 +32,43 @@ class Conv2dSubsampling(nn.Module): https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa """ - def __init__(self, idim: int, odim: int) -> None: + def __init__(self, in_channels: int, + out_channels: int, + layer1_channels: int = 64, + layer2_channels: int = 128) -> None: """ Args: - idim: - Input dim. The input shape is (N, T, idim). - Caution: It requires: T >=7, idim >=7 - odim: - Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 """ - assert idim >= 7 + assert in_channels >= 7 super().__init__() self.conv = nn.Sequential( ScaledConv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 + in_channels=1, out_channels=layer1_channels, + kernel_size=3, stride=2 ), ActivationBalancer(channel_dim=1), DoubleSwish(), ScaledConv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + in_channels=layer1_channels, out_channels=layer2_channels, + kernel_size=3, stride=2 ), ActivationBalancer(channel_dim=1), DoubleSwish(), ) - self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) + self.out = ScaledLinear(layer2_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels) # set learn_eps=False because out_norm is preceded by `out`, and `out` # itself has learned scale, so the extra degree of freedom is not # needed. - self.out_norm = BasicNorm(odim, learn_eps=False) + self.out_norm = BasicNorm(out_channels, learn_eps=False) # constrain median of output to be close to zero. self.out_balancer = ActivationBalancer(channel_dim=-1, min_positive=0.45, @@ -86,99 +95,3 @@ class Conv2dSubsampling(nn.Module): x = self.out_norm(x) x = self.out_balancer(x) return x - - -class VggSubsampling(nn.Module): - """Trying to follow the setup described in the following paper: - https://arxiv.org/pdf/1910.09799.pdf - - This paper is not 100% explicit so I am guessing to some extent, - and trying to compare with other VGG implementations. - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = ((T-1)//2 - 1)//2, which approximates T' = T//4 - """ - - def __init__(self, idim: int, odim: int) -> None: - """Construct a VggSubsampling object. - - This uses 2 VGG blocks with 2 Conv2d layers each, - subsampling its input by a factor of 4 in the time dimensions. - - Args: - idim: - Input dim. The input shape is (N, T, idim). - Caution: It requires: T >=7, idim >=7 - odim: - Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) - """ - super().__init__() - - cur_channels = 1 - layers = [] - block_dims = [32, 64] - - # The decision to use padding=1 for the 1st convolution, then padding=0 - # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by - # a back-compatibility concern so that the number of frames at the - # output would be equal to: - # (((T-1)//2)-1)//2. - # We can consider changing this by using padding=1 on the - # 2nd convolution, so the num-frames at the output would be T//4. - for block_dim in block_dims: - layers.append( - torch.nn.Conv2d( - in_channels=cur_channels, - out_channels=block_dim, - kernel_size=3, - padding=1, - stride=1, - ) - ) - layers.append(torch.nn.ReLU()) - layers.append( - torch.nn.Conv2d( - in_channels=block_dim, - out_channels=block_dim, - kernel_size=3, - padding=0, - stride=1, - ) - ) - layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) - ) - cur_channels = block_dim - - self.layers = nn.Sequential(*layers) - - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) - self.out_norm = BasicNorm(odim, learn_eps=False) - # constrain median of output to be close to zero. - self.out_balancer = ActivationBalancer(channel_dim=-1, - min_positive=0.45, - max_positive=0.55) - - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - - Returns: - Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) - """ - x = x.unsqueeze(1) - x = self.layers(x) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - x = self.out_norm(x) - x = self.out_balancer(x) - return x diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index c1e836903..237eb8bbd 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -291,7 +291,6 @@ def get_params() -> AttributeDict: "nhead": 8, "dim_feedforward": 2048, "num_encoder_layers": 12, - "vgg_frontend": False, # parameters for decoder "embedding_dim": 512, # parameters for Noam @@ -314,7 +313,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: nhead=params.nhead, dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, ) return encoder diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py index aa091877c..a58702e1d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py @@ -78,10 +78,7 @@ class Transformer(EncoderInterface): # That is, it does two things simultaneously: # (1) subsampling: T -> T//subsampling_factor # (2) embedding: num_features -> d_model - if vgg_frontend: - self.encoder_embed = VggSubsampling(num_features, d_model) - else: - self.encoder_embed = Conv2dSubsampling(num_features, d_model) + self.encoder_embed = Conv2dSubsampling(num_features, d_model) self.encoder_pos = PositionalEncoding(d_model, dropout) From 4e453a4bf9c77bfaa19a955921a9e7218548b2eb Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 29 Mar 2022 23:41:13 +0800 Subject: [PATCH 129/234] Rework conformer, remove some code. --- .../pruned_transducer_stateless2/conformer.py | 90 +++- .../subsampling.py | 97 ---- .../ASR/pruned_transducer_stateless2/train.py | 3 +- .../transformer.py | 416 ------------------ 4 files changed, 90 insertions(+), 516 deletions(-) delete mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py delete mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 6b625513e..0b9d64ee9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -69,7 +69,7 @@ class Conformer(EncoderInterface): # That is, it does two things simultaneously: # (1) subsampling: T -> T//subsampling_factor # (2) embedding: num_features -> d_model - self.encoder_embed = Conv2dSubsampling(num_features, 128, d_model) + self.encoder_embed = Conv2dSubsampling(num_features, d_model) self.encoder_pos = RelPositionalEncoding(d_model, dropout) @@ -1017,6 +1017,94 @@ class Conv2dSubsampling(nn.Module): return x +class Noam(object): + """ + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * self.warmup ** (-0.5 - -0.333) + * min(step ** (-0.333), step * self.warmup ** (-1.333)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) + + if __name__ == '__main__': feature_dim = 50 c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py deleted file mode 100644 index 12ca09a17..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/subsampling.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -import torch.nn as nn -from torch import Tensor -from typing import Tuple -from scaling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__(self, in_channels: int, - out_channels: int, - layer1_channels: int = 64, - layer2_channels: int = 128) -> None: - """ - Args: - in_channels: - Number of channels in. The input shape is (N, T, in_channels). - Caution: It requires: T >=7, in_channels >=7 - out_channels - Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels) - layer1_channels: - Number of channels in layer1 - layer1_channels: - Number of channels in layer2 - """ - assert in_channels >= 7 - super().__init__() - self.conv = nn.Sequential( - ScaledConv2d( - in_channels=1, out_channels=layer1_channels, - kernel_size=3, stride=2 - ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ScaledConv2d( - in_channels=layer1_channels, out_channels=layer2_channels, - kernel_size=3, stride=2 - ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ) - self.out = ScaledLinear(layer2_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels) - # set learn_eps=False because out_norm is preceded by `out`, and `out` - # itself has learned scale, so the extra degree of freedom is not - # needed. - self.out_norm = BasicNorm(out_channels, learn_eps=False) - # constrain median of output to be close to zero. - self.out_balancer = ActivationBalancer(channel_dim=-1, - min_positive=0.45, - max_positive=0.55) - - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - - Returns: - Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - x = self.conv(x) - # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) - x = self.out_norm(x) - x = self.out_balancer(x) - return x diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 237eb8bbd..8d5142937 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -44,7 +44,7 @@ import torch import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from conformer import Conformer +from conformer import Conformer, Noam from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -54,7 +54,6 @@ from model import Transducer from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from transformer import Noam from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py deleted file mode 100644 index a58702e1d..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/transformer.py +++ /dev/null @@ -1,416 +0,0 @@ -# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import math -from typing import Optional, Tuple - -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface -from subsampling import Conv2dSubsampling, VggSubsampling, ScaledLinear - -from icefall.utils import make_pad_mask - - -class Transformer(EncoderInterface): - def __init__( - self, - num_features: int, - output_dim: int, - subsampling_factor: int = 4, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, - dropout: float = 0.1, - normalize_before: bool = True, - vgg_frontend: bool = False, - ) -> None: - """ - Args: - num_features: - The input dimension of the model. - output_dim: - The output dimension of the model. - subsampling_factor: - Number of output frames is num_in_frames // subsampling_factor. - Currently, subsampling_factor MUST be 4. - d_model: - Attention dimension. - nhead: - Number of heads in multi-head attention. - Must satisfy d_model // nhead == 0. - dim_feedforward: - The output dimension of the feedforward layers in encoder. - num_encoder_layers: - Number of encoder layers. - dropout: - Dropout in encoder. - normalize_before: - If True, use pre-layer norm; False to use post-layer norm. - vgg_frontend: - True to use vgg style frontend for subsampling. - """ - super().__init__() - - self.num_features = num_features - self.output_dim = output_dim - self.subsampling_factor = subsampling_factor - if subsampling_factor != 4: - raise NotImplementedError("Support only 'subsampling_factor=4'.") - - # self.encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, T//subsampling_factor, d_model). - # That is, it does two things simultaneously: - # (1) subsampling: T -> T//subsampling_factor - # (2) embedding: num_features -> d_model - self.encoder_embed = Conv2dSubsampling(num_features, d_model) - - self.encoder_pos = PositionalEncoding(d_model, dropout) - - encoder_layer = TransformerEncoderLayer( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - normalize_before=normalize_before, - ) - - if normalize_before: - encoder_norm = nn.LayerNorm(d_model) - else: - encoder_norm = None - - self.encoder = nn.TransformerEncoder( - encoder_layer=encoder_layer, - num_layers=num_encoder_layers, - norm=encoder_norm, - ) - - # TODO(fangjun): remove dropout - self.encoder_output_layer = nn.Sequential( - nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim) - ) - - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - Returns: - Return a tuple containing 2 tensors: - - logits, its shape is (batch_size, output_seq_len, output_dim) - - logit_lens, a tensor of shape (batch_size,) containing the number - of frames in `logits` before padding. - """ - x = self.encoder_embed(x) - x = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - # Caution: We assume the subsampling factor is 4! - lengths = ((x_lens - 1) // 2 - 1) // 2 - assert x.size(0) == lengths.max().item() - - mask = make_pad_mask(lengths) - x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) - - logits = self.encoder_output_layer(x) - logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return logits, lengths - - -class TransformerEncoderLayer(nn.Module): - """ - Modified from torch.nn.TransformerEncoderLayer. - Add support of normalize_before, - i.e., use layer_norm before the first block. - - Args: - d_model: - the number of expected features in the input (required). - nhead: - the number of heads in the multiheadattention models (required). - dim_feedforward: - the dimension of the feedforward network model (default=2048). - dropout: - the dropout value (default=0.1). - activation: - the activation function of intermediate layer, relu or - gelu (default=relu). - normalize_before: - whether to use layer_norm before the first block. - - Examples:: - >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> out = encoder_layer(src) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: str = "relu", - normalize_before: bool = True, - ) -> None: - super(TransformerEncoderLayer, self).__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - self.normalize_before = normalize_before - - def __setstate__(self, state): - if "activation" not in state: - state["activation"] = nn.functional.relu - super(TransformerEncoderLayer, self).__setstate__(state) - - def forward( - self, - src: torch.Tensor, - src_mask: Optional[torch.Tensor] = None, - src_key_padding_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional) - - Shape: - src: (S, N, E). - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, - N is the batch size, E is the feature number - """ - residual = src - if self.normalize_before: - src = self.norm1(src) - src2 = self.self_attn( - src, - src, - src, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - )[0] - src = residual + self.dropout1(src2) - if not self.normalize_before: - src = self.norm1(src) - - residual = src - if self.normalize_before: - src = self.norm2(src) - src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) - src = residual + self.dropout2(src2) - if not self.normalize_before: - src = self.norm2(src) - return src - - -def _get_activation_fn(activation: str): - if activation == "relu": - return nn.functional.relu - elif activation == "gelu": - return nn.functional.gelu - - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) - - -class PositionalEncoding(nn.Module): - """This class implements the positional encoding - proposed in the following paper: - - - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf - - PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) - PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) - - Note:: - - 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) - = exp(-1* 2i / d_model * log(100000)) - = exp(2i * -(log(10000) / d_model)) - """ - - def __init__(self, d_model: int, dropout: float = 0.1) -> None: - """ - Args: - d_model: - Embedding dimension. - dropout: - Dropout probability to be applied to the output of this module. - """ - super().__init__() - self.d_model = d_model - self.xscale = math.sqrt(self.d_model) - self.dropout = nn.Dropout(p=dropout) - # not doing: self.pe = None because of errors thrown by torchscript - self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32) - - def extend_pe(self, x: torch.Tensor) -> None: - """Extend the time t in the positional encoding if required. - - The shape of `self.pe` is (1, T1, d_model). The shape of the input x - is (N, T, d_model). If T > T1, then we change the shape of self.pe - to (N, T, d_model). Otherwise, nothing is done. - - Args: - x: - It is a tensor of shape (N, T, C). - Returns: - Return None. - """ - if self.pe is not None: - if self.pe.size(1) >= x.size(1): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) - position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, self.d_model, 2, dtype=torch.float32) - * -(math.log(10000.0) / self.d_model) - ) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - # Now pe is of shape (1, T, d_model), where T is x.size(1) - self.pe = pe.to(device=x.device, dtype=x.dtype) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Add positional encoding. - - Args: - x: - Its shape is (N, T, C) - - Returns: - Return a tensor of shape (N, T, C) - """ - self.extend_pe(x) - x = x * self.xscale + self.pe[:, : x.size(1), :] - return self.dropout(x) - - -class Noam(object): - """ - Implements Noam optimizer. - - Proposed in - "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf - - Modified from - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa - - Args: - params: - iterable of parameters to optimize or dicts defining parameter groups - model_size: - attention dimension of the transformer model - factor: - learning rate factor - warm_step: - warmup steps - """ - - def __init__( - self, - params, - model_size: int = 256, - factor: float = 10.0, - warm_step: int = 25000, - weight_decay=0, - ) -> None: - """Construct an Noam object.""" - self.optimizer = torch.optim.Adam( - params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay - ) - self._step = 0 - self.warmup = warm_step - self.factor = factor - self.model_size = model_size - self._rate = 0 - - @property - def param_groups(self): - """Return param_groups.""" - return self.optimizer.param_groups - - def step(self): - """Update parameters and rate.""" - self._step += 1 - rate = self.rate() - for p in self.optimizer.param_groups: - p["lr"] = rate - self._rate = rate - self.optimizer.step() - - def rate(self, step=None): - """Implement `lrate` above.""" - if step is None: - step = self._step - return ( - self.factor - * self.model_size ** (-0.5) - * self.warmup ** (-0.5 - -0.333) - * min(step ** (-0.333), step * self.warmup ** (-1.333)) - ) - - def zero_grad(self): - """Reset gradient.""" - self.optimizer.zero_grad() - - def state_dict(self): - """Return state_dict.""" - return { - "_step": self._step, - "warmup": self.warmup, - "factor": self.factor, - "model_size": self.model_size, - "_rate": self._rate, - "optimizer": self.optimizer.state_dict(), - } - - def load_state_dict(self, state_dict): - """Load state_dict.""" - for key, value in state_dict.items(): - if key == "optimizer": - self.optimizer.load_state_dict(state_dict["optimizer"]) - else: - setattr(self, key, value) From 1b8d7defd06c7e12b47e0bde0c8092a053d7f377 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Mar 2022 00:44:18 +0800 Subject: [PATCH 130/234] Reduce 1st conv channels from 64 to 32 --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 0b9d64ee9..628d31d4b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -954,7 +954,7 @@ class Conv2dSubsampling(nn.Module): def __init__(self, in_channels: int, out_channels: int, - layer1_channels: int = 64, + layer1_channels: int = 32, layer2_channels: int = 128) -> None: """ Args: From ca6337b78aaedff4404135558cf99f9ad7ab7123 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Mar 2022 11:11:32 +0800 Subject: [PATCH 131/234] Add another convolutional layer --- .../ASR/pruned_transducer_stateless2/conformer.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 628d31d4b..eb937e0c3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -954,8 +954,9 @@ class Conv2dSubsampling(nn.Module): def __init__(self, in_channels: int, out_channels: int, - layer1_channels: int = 32, - layer2_channels: int = 128) -> None: + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128) -> None: """ Args: in_channels: @@ -973,7 +974,7 @@ class Conv2dSubsampling(nn.Module): self.conv = nn.Sequential( ScaledConv2d( in_channels=1, out_channels=layer1_channels, - kernel_size=3, stride=2 + kernel_size=3, ), ActivationBalancer(channel_dim=1), DoubleSwish(), @@ -983,8 +984,14 @@ class Conv2dSubsampling(nn.Module): ), ActivationBalancer(channel_dim=1), DoubleSwish(), + ScaledConv2d( + in_channels=layer2_channels, out_channels=layer3_channels, + kernel_size=3, stride=2 + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), ) - self.out = ScaledLinear(layer2_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels) + self.out = ScaledLinear(layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels) # set learn_eps=False because out_norm is preceded by `out`, and `out` # itself has learned scale, so the extra degree of freedom is not # needed. From 21a099b110e2831664cd79b5fd982b87606c512c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Mar 2022 11:18:04 +0800 Subject: [PATCH 132/234] Fix padding bug --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index eb937e0c3..853d6747b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -974,7 +974,7 @@ class Conv2dSubsampling(nn.Module): self.conv = nn.Sequential( ScaledConv2d( in_channels=1, out_channels=layer1_channels, - kernel_size=3, + kernel_size=3, padding=1, ), ActivationBalancer(channel_dim=1), DoubleSwish(), From 7c46c3b0d4ae7237ea1ac909d44a605507c27a77 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Mar 2022 11:20:04 +0800 Subject: [PATCH 133/234] Remove dropout in output layer --- .../ASR/pruned_transducer_stateless2/conformer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 0b9d64ee9..a8475c21e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -83,10 +83,7 @@ class Conformer(EncoderInterface): self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) - self.encoder_output_layer = nn.Sequential( - nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim) - ) - + self.encoder_output_layer = ScaledLinear(d_model, output_dim) def forward( self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 From 37ab0bcfa56fa063c7e5aadfe3f0f207c53e5518 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Mar 2022 11:46:23 +0800 Subject: [PATCH 134/234] Reduce speed of some components --- .../pruned_transducer_stateless2/conformer.py | 12 +++- .../pruned_transducer_stateless2/decoder.py | 7 ++ .../pruned_transducer_stateless2/scaling.py | 67 +++++++++++++------ 3 files changed, 64 insertions(+), 22 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index a8475c21e..0d3b0aa02 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -967,16 +967,24 @@ class Conv2dSubsampling(nn.Module): """ assert in_channels >= 7 super().__init__() + + # This initial_speed is to slightly slow down the relative speed of + # training during the warmup phase by increasing the magnitude of the + # initial parameter values. The intention is to allow us to + # use a higher lr_factor. + initial_speed = 0.5 self.conv = nn.Sequential( ScaledConv2d( in_channels=1, out_channels=layer1_channels, - kernel_size=3, stride=2 + kernel_size=3, stride=2, + initial_speed=initial_speed, ), ActivationBalancer(channel_dim=1), DoubleSwish(), ScaledConv2d( in_channels=layer1_channels, out_channels=layer2_channels, - kernel_size=3, stride=2 + kernel_size=3, stride=2, + initial_speed=initial_speed, ), ActivationBalancer(channel_dim=1), DoubleSwish(), diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index 13e45e03b..3470b647f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -55,10 +55,17 @@ class Decoder(nn.Module): 1 means bigram; 2 means trigram. n means (n+1)-gram. """ super().__init__() + + # This initial_speed is to slightly slow down the relative speed of + # training during the warmup phase by increasing the magnitude of the + # initial parameter values. The intention is to allow us to + # use a higher lr_factor. + initial_speed = 0.5 self.embedding = ScaledEmbedding( num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=blank_id, + initial_speed=initial_speed ) self.blank_id = blank_id diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index f2423492f..4c45205ce 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -134,13 +134,18 @@ class ScaledLinear(nn.Linear): (affects the initialization of weight_scale and bias_scale). Another option, if you want to do something like this, is to re-initialize the parameters. - - Note: it uses the default initialization for the weight and bias, - inherited from nn.Linear. For modules with small fan-in, this - may be larger than optimal. + initial_speed: this affects how fast the parameter will + learn near the start of training; you can set it to a + value less than one if you suspect that a module + is contributing to instability near the start of training. + Nnote: regardless of the use of this option, it's best to + use schedulers like Noam that have a warm-up period. + Alternatively you can set it to more than 1 if you want it to + initially train faster. Must be greater than 0. """ def __init__(self, *args, initial_scale: float = 1.0, + initial_speed: float = 1.0, **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -150,10 +155,10 @@ class ScaledLinear(nn.Linear): else: self.register_parameter('bias_scale', None) - self._reset_parameters() # Overrides the reset_parameters in nn.Linear + self._reset_parameters(initial_speed) # Overrides the reset_parameters in nn.Linear - def _reset_parameters(self): - std = 0.01 + def _reset_parameters(self, initial_speed: float): + std = 0.01 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -176,8 +181,11 @@ class ScaledLinear(nn.Linear): class ScaledConv1d(nn.Conv1d): + # See docs for ScaledLinear def __init__(self, *args, - initial_scale=1.0, **kwargs): + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) @@ -185,10 +193,10 @@ class ScaledConv1d(nn.Conv1d): self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: self.register_parameter('bias_scale', None) - self._reset_parameters() # Overrides the reset_parameters in base class + self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class - def _reset_parameters(self): - std = 0.01 + def _reset_parameters(self, initial_speed: float): + std = 0.01 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -218,7 +226,11 @@ class ScaledConv1d(nn.Conv1d): class ScaledConv2d(nn.Conv2d): - def __init__(self, *args, initial_scale=1.0, **kwargs): + # See docs for ScaledLinear + def __init__(self, *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs): super(ScaledConv2d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) @@ -226,10 +238,10 @@ class ScaledConv2d(nn.Conv2d): self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: self.register_parameter('bias_scale', None) - self._reset_parameters() # Overrides the reset_parameters in base class + self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class - def _reset_parameters(self): - std = 0.01 + def _reset_parameters(self, initial_speed: float): + std = 0.01 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -350,7 +362,11 @@ class DoubleSwish(torch.nn.Module): class ScaledEmbedding(nn.Module): - r"""A simple lookup table that stores embeddings of a fixed dictionary and size. + r"""This is a modified version of nn.Embedding that introduces a learnable scale + on the parameters. Note: due to how we initialize it, it's best used with + schedulers like Noam that have a warmup period. + + It is a simple lookup table that stores embeddings of a fixed dictionary and size. This module is often used to store word embeddings and retrieve them using indices. The input to the module is a list of indices, and the output is the corresponding @@ -369,6 +385,15 @@ class ScaledEmbedding(nn.Module): sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See Notes for more details regarding sparse gradients. + initial_speed (float, optional): This affects how fast the parameter will + learn near the start of training; you can set it to a value less than + one if you suspect that a module is contributing to instability near + the start of training. Nnote: regardless of the use of this option, + it's best to use schedulers like Noam that have a warm-up period. + Alternatively you can set it to more than 1 if you want it to + initially train faster. Must be greater than 0. + + Attributes: weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) initialized from :math:`\mathcal{N}(0, 1)` @@ -416,6 +441,7 @@ class ScaledEmbedding(nn.Module): [ 0.1535, -2.0309, 0.9315], [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) + """ __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'scale_grad_by_freq', 'sparse'] @@ -429,7 +455,8 @@ class ScaledEmbedding(nn.Module): def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, scale_grad_by_freq: bool = False, - sparse: bool = False) -> None: + sparse: bool = False, + initial_speed: float = 1.0) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim @@ -446,12 +473,12 @@ class ScaledEmbedding(nn.Module): self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) - self.reset_parameters() + self.reset_parameters(initial_speed) - def reset_parameters(self) -> None: - std = 0.01 + def reset_parameters(self, initial_speed: float = 1.0) -> None: + std = 0.01 / initial_speed nn.init.normal_(self.weight, std=std) nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) From 709c387ce63a9e1eeee2e34de831f27f2b72b9cf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Mar 2022 21:40:22 +0800 Subject: [PATCH 135/234] Initial refactoring to remove unnecessary vocab_size --- .../pruned_transducer_stateless2/conformer.py | 25 +++++++++++-------- .../pruned_transducer_stateless2/decoder.py | 3 +-- .../pruned_transducer_stateless2/joiner.py | 9 +++---- .../ASR/pruned_transducer_stateless2/model.py | 11 ++++++-- .../ASR/pruned_transducer_stateless2/train.py | 7 +++--- 5 files changed, 31 insertions(+), 24 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index d8b184752..03a47927f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -32,9 +32,10 @@ class Conformer(EncoderInterface): """ Args: num_features (int): Number of input features - output_dim (int): Number of output dimension + output_dim (int): Model output dimension. If not equal to the encoder dimension, + we will project to the output. subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) - d_model (int): attention dimension + d_model (int): attention dimension, also the output dimension nhead (int): number of head dim_feedforward (int): feedforward dimention num_encoder_layers (int): number of encoder layers @@ -42,7 +43,6 @@ class Conformer(EncoderInterface): cnn_module_kernel (int): Kernel size of convolution module vgg_frontend (bool): whether to use vgg frontend. """ - def __init__( self, num_features: int, @@ -59,7 +59,6 @@ class Conformer(EncoderInterface): super(Conformer, self).__init__() self.num_features = num_features - self.output_dim = output_dim self.subsampling_factor = subsampling_factor if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") @@ -83,7 +82,11 @@ class Conformer(EncoderInterface): self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) - self.encoder_output_layer = ScaledLinear(d_model, output_dim) + if output_dim == d_model: + self.encoder_output_layer = Identity() + else: + self.encoder_output_layer = ScaledLinear(d_model, output_dim, + initial_speed=0.5) def forward( self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 @@ -101,9 +104,9 @@ class Conformer(EncoderInterface): to turn modules on sequentially. Returns: Return a tuple containing 2 tensors: - - logits, its shape is (batch_size, output_seq_len, output_dim) - - logit_lens, a tensor of shape (batch_size,) containing the number - of frames in `logits` before padding. + - embeddings: its shape is (batch_size, output_seq_len, d_model) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. """ x = self.encoder_embed(x) x, pos_emb = self.encoder_pos(x) @@ -117,10 +120,10 @@ class Conformer(EncoderInterface): x = self.encoder(x, pos_emb, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) - logits = self.encoder_output_layer(x) - logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + x = self.encoder_output_layer(x) + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - return logits, lengths + return x, lengths class ConformerEncoderLayer(nn.Module): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index 3470b647f..a442feeea 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -68,6 +68,7 @@ class Decoder(nn.Module): initial_speed=initial_speed ) self.blank_id = blank_id + self.output_linear = ScaledLinear(embedding_dim, embedding_dim) assert context_size >= 1, context_size self.context_size = context_size @@ -81,8 +82,6 @@ class Decoder(nn.Module): groups=embedding_dim, bias=False, ) - self.output_linear = ScaledLinear(embedding_dim, vocab_size) - def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 61bfe8186..973a89bfe 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -20,11 +20,10 @@ import torch.nn.functional as F from scaling import ScaledLinear class Joiner(nn.Module): - def __init__(self, input_dim: int, inner_dim: int, output_dim: int): + def __init__(self, input_dim: int, output_dim: int): super().__init__() - self.inner_linear = ScaledLinear(input_dim, inner_dim) - self.output_linear = ScaledLinear(inner_dim, output_dim) + self.output_linear = ScaledLinear(input_dim, output_dim) def forward( self, encoder_out: torch.Tensor, decoder_out: torch.Tensor @@ -43,8 +42,6 @@ class Joiner(nn.Module): logit = encoder_out + decoder_out - logit = self.inner_linear(torch.tanh(logit)) - - output = self.output_linear(F.relu(logit)) + logit = self.output_linear(torch.tanh(logit)) return output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index faaebc477..2f102bdf8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -19,6 +19,7 @@ import k2 import torch import torch.nn as nn from encoder_interface import EncoderInterface +from scaling import ScaledLinear from icefall.utils import add_sos @@ -33,6 +34,8 @@ class Transducer(nn.Module): encoder: EncoderInterface, decoder: nn.Module, joiner: nn.Module, + embedding_dim: int, + vocab_size: int ): """ Args: @@ -58,6 +61,10 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner + # could perhaps separate this into 2 linear projections, one + # for lm and one for am. + self.simple_joiner = nn.Linear(embedding_dim, vocab_size) + def forward( self, x: torch.Tensor, @@ -133,8 +140,8 @@ class Transducer(nn.Module): boundary[:, 3] = x_lens simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=decoder_out, - am=encoder_out, + lm=self.simple_joiner(decoder_out), + am=self.simple_joiner(encoder_out), symbols=y_padded, termination_symbol=blank_id, lm_only_scale=lm_scale, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 8d5142937..649234f0f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -306,7 +306,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: # TODO: We can add an option to switch between Conformer and Transformer encoder = Conformer( num_features=params.feature_dim, - output_dim=params.vocab_size, + output_dim=params.embedding_dim, subsampling_factor=params.subsampling_factor, d_model=params.attention_dim, nhead=params.nhead, @@ -328,8 +328,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - input_dim=params.vocab_size, - inner_dim=params.embedding_dim, + input_dim=params.embedding_dim, output_dim=params.vocab_size, ) return joiner @@ -344,6 +343,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, + embedding_dim=params.embedding_dim, + vocab_size=params.vocab_size, ) return model From f87811e65c1f9cdace638122df5f29c150a50b60 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Mar 2022 21:41:46 +0800 Subject: [PATCH 136/234] Fix RE identity --- .../ASR/pruned_transducer_stateless2/conformer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 03a47927f..528cc48f4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -83,7 +83,7 @@ class Conformer(EncoderInterface): aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) if output_dim == d_model: - self.encoder_output_layer = Identity() + self.encoder_output_layer = nn.Identity() else: self.encoder_output_layer = ScaledLinear(d_model, output_dim, initial_speed=0.5) @@ -936,10 +936,6 @@ class ConvolutionModule(nn.Module): return x.permute(2, 0, 1) -class Identity(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - return x - class Conv2dSubsampling(nn.Module): """Convolutional 2D subsampling (to 1/4 length). From a2aca9f64371b3be66cba65bac9f6b60346a9126 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 30 Mar 2022 21:42:15 +0800 Subject: [PATCH 137/234] Bug-fix --- egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 973a89bfe..d76a913a5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -44,4 +44,4 @@ class Joiner(nn.Module): logit = self.output_linear(torch.tanh(logit)) - return output + return logit From 0599f382810c9f9f2bbad39dacd3c8159bd43a06 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 11:53:54 +0800 Subject: [PATCH 138/234] Add final dropout to conformer --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 528cc48f4..8d4057e71 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -82,6 +82,7 @@ class Conformer(EncoderInterface): self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) + self.final_dropout = nn.Dropout(p=dropout) if output_dim == d_model: self.encoder_output_layer = nn.Identity() else: @@ -120,6 +121,7 @@ class Conformer(EncoderInterface): x = self.encoder(x, pos_emb, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) + x = self.final_dropout(x) x = self.encoder_output_layer(x) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) From f47fe8337aec12d8d7a005855e763b695f37e9d1 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 12:16:08 +0800 Subject: [PATCH 139/234] Remove some un-used code --- .../ASR/pruned_transducer_stateless2/conformer.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 528cc48f4..abe30633c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -54,7 +54,6 @@ class Conformer(EncoderInterface): num_encoder_layers: int = 12, dropout: float = 0.1, cnn_module_kernel: int = 31, - aux_layer_period: int = 3 ) -> None: super(Conformer, self).__init__() @@ -79,8 +78,7 @@ class Conformer(EncoderInterface): dropout, cnn_module_kernel, ) - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, - aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) if output_dim == d_model: self.encoder_output_layer = nn.Identity() @@ -277,16 +275,13 @@ class ConformerEncoder(nn.Module): >>> out = conformer_encoder(src, pos_emb) """ - def __init__(self, encoder_layer: nn.Module, num_layers: int, - aux_layers: Sequence[int]) -> None: + def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: super().__init__() self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] ) - self.aux_layers = set(aux_layers + [num_layers - 1]) - assert num_layers - 1 not in aux_layers self.num_layers = num_layers - num_channels = encoder_layer.d_model + def forward( self, From f75d40c725f6d9ebacc5e02581066dc5ec4de762 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 12:18:31 +0800 Subject: [PATCH 140/234] Replace nn.Linear with ScaledLinear in simple joiner --- egs/librispeech/ASR/pruned_transducer_stateless2/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 2f102bdf8..f1a3d4d11 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -63,7 +63,7 @@ class Transducer(nn.Module): # could perhaps separate this into 2 linear projections, one # for lm and one for am. - self.simple_joiner = nn.Linear(embedding_dim, vocab_size) + self.simple_joiner = ScaledLinear(embedding_dim, vocab_size) def forward( self, From c67ae0f3a132189da402eca9d4886e664699862d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 13:02:40 +0800 Subject: [PATCH 141/234] Make 2 projections.. --- .../ASR/pruned_transducer_stateless2/joiner.py | 3 ++- .../ASR/pruned_transducer_stateless2/model.py | 16 +++++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index d76a913a5..b9c465398 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -23,7 +23,8 @@ class Joiner(nn.Module): def __init__(self, input_dim: int, output_dim: int): super().__init__() - self.output_linear = ScaledLinear(input_dim, output_dim) + self.output_linear = ScaledLinear(input_dim, output_dim, + initial_speed=0.5) def forward( self, encoder_out: torch.Tensor, decoder_out: torch.Tensor diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index f1a3d4d11..ab729a429 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -61,9 +61,15 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - # could perhaps separate this into 2 linear projections, one - # for lm and one for am. - self.simple_joiner = ScaledLinear(embedding_dim, vocab_size) + self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size) + self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size) + with torch.no_grad(): + # Initialize the two projections to be the same; this will be + # convenient for the real joiner, which adds the endcoder + # (acoustic-model/am) and decoder (language-model/lm) embeddings + self.simple_lm_proj.weight[:] = self.simple_am_proj.weight + self.simple_lm_proj.bias[:] = self.simple_am_proj.bias + def forward( self, @@ -140,8 +146,8 @@ class Transducer(nn.Module): boundary[:, 3] = x_lens simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=self.simple_joiner(decoder_out), - am=self.simple_joiner(encoder_out), + lm=self.simple_lm_proj(decoder_out), + am=self.simple_am_proj(encoder_out), symbols=y_padded, termination_symbol=blank_id, lm_only_scale=lm_scale, From e59db01b7c599afb0c780a33289b4b86c6579afe Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 13:03:26 +0800 Subject: [PATCH 142/234] Reduce initial_speed --- egs/librispeech/ASR/pruned_transducer_stateless2/model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index ab729a429..0355c4531 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -61,8 +61,10 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size) - self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size) + self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size, + initial_speed=0.5) + self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size, + initial_speed=0.5) with torch.no_grad(): # Initialize the two projections to be the same; this will be # convenient for the real joiner, which adds the endcoder From ec54fa85cc9cd8cba6b87c0599464fa499523e27 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 13:04:09 +0800 Subject: [PATCH 143/234] Use initial_speed=0.5 --- egs/librispeech/ASR/pruned_transducer_stateless2/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index f1a3d4d11..9fef48fcc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -63,7 +63,8 @@ class Transducer(nn.Module): # could perhaps separate this into 2 linear projections, one # for lm and one for am. - self.simple_joiner = ScaledLinear(embedding_dim, vocab_size) + self.simple_joiner = ScaledLinear(embedding_dim, vocab_size, + initial_speed=0.5) def forward( self, From 025d6909951502ec35187129c4b05b7d40f3b85b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 13:39:56 +0800 Subject: [PATCH 144/234] Reduce initial_speed further from 0.5 to 0.25 --- egs/librispeech/ASR/pruned_transducer_stateless2/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 9fef48fcc..83405be36 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -64,7 +64,7 @@ class Transducer(nn.Module): # could perhaps separate this into 2 linear projections, one # for lm and one for am. self.simple_joiner = ScaledLinear(embedding_dim, vocab_size, - initial_speed=0.5) + initial_speed=0.25) def forward( self, From fcb0dba2cfe1d84ac10472dc3a745ed936053246 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 13:47:28 +0800 Subject: [PATCH 145/234] Reduce initial_speed from 0.5 to 0.25 --- egs/librispeech/ASR/pruned_transducer_stateless2/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 0355c4531..47a7169b1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -62,9 +62,9 @@ class Transducer(nn.Module): self.joiner = joiner self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size, - initial_speed=0.5) + initial_speed=0.25) self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size, - initial_speed=0.5) + initial_speed=0.25) with torch.no_grad(): # Initialize the two projections to be the same; this will be # convenient for the real joiner, which adds the endcoder From e6637132584c1c0287e8cdc80cb0f0e5b22cce6b Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 14:43:49 +0800 Subject: [PATCH 146/234] Change how warmup is applied. --- .../pruned_transducer_stateless2/conformer.py | 24 ++++++------------- .../ASR/pruned_transducer_stateless2/model.py | 2 +- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 12095810e..704c17dd7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -206,10 +206,8 @@ class ConformerEncoderLayer(nn.Module): pos_emb: Positional embedding tensor (required). src_mask: the mask for the src sequence (optional). src_key_padding_mask: the mask for the src keys per batch (optional). - warmup: controls selective activation of layers; if < 0.5, it's possible that - not all modules will be included. Actually we add the - feed_forward_macaron and self_attn modules at warmup=0.0 - and the conv_module and feed_forward at warmup=0.5. + warmup: controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. Shape: src: (S, N, E). @@ -219,19 +217,11 @@ class ConformerEncoderLayer(nn.Module): S is the source sequence length, N is the batch size, E is the feature number """ src_orig = src - # when warmup == 0.0, alpha is always 0.1, but it gradually changes to - # always being 1.0 when warmup equals 1.0. The reason for using 0.1 and not - # 0.0 is that it gives us a gradient so we can learn something when we are turned - # off. - # - # min(0.1, warmup) - # is used in place of warmup to ensure that even at the start of the warm-up - # period we sometimes use scale 1.0; this ensures that the modules do not - # compensate for the small scale by just producing larger output. - warmup = max(warmup, 0.1) - if self.training: - warmup = min(warmup, 0.95) # effectively, layer-drop with 1-in-20 prob. - alpha = 1.0 if torch.rand(()).item() <= warmup else 0.1 + + warmup_scale = min(0.1 + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean completely + # bypass it. + alpha = 0.1 if torch.rand(()).item() <= 0.9 else warmup_scale # macaron style feed forward module src = src + self.dropout(self.feed_forward_macaron(src)) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 83405be36..9fef48fcc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -64,7 +64,7 @@ class Transducer(nn.Module): # could perhaps separate this into 2 linear projections, one # for lm and one for am. self.simple_joiner = ScaledLinear(embedding_dim, vocab_size, - initial_speed=0.25) + initial_speed=0.5) def forward( self, From 8caa18e2fe1d03035dbfae1a60878cf727861d44 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 31 Mar 2022 17:30:51 +0800 Subject: [PATCH 147/234] Bug fix to warmup_scale --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 704c17dd7..8778dc5ba 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -221,7 +221,7 @@ class ConformerEncoderLayer(nn.Module): warmup_scale = min(0.1 + warmup, 1.0) # alpha = 1.0 means fully use this encoder layer, 0.0 would mean completely # bypass it. - alpha = 0.1 if torch.rand(()).item() <= 0.9 else warmup_scale + alpha = warmup_scale if torch.rand(()).item() <= 0.9 else 0.1 # macaron style feed forward module src = src + self.dropout(self.feed_forward_macaron(src)) From 92ec2e356e02cbc9b5493048d6108b1148de40be Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 1 Apr 2022 12:22:12 +0800 Subject: [PATCH 148/234] Fix test-mode --- .../ASR/pruned_transducer_stateless2/conformer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 8778dc5ba..83de82056 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -221,7 +221,10 @@ class ConformerEncoderLayer(nn.Module): warmup_scale = min(0.1 + warmup, 1.0) # alpha = 1.0 means fully use this encoder layer, 0.0 would mean completely # bypass it. - alpha = warmup_scale if torch.rand(()).item() <= 0.9 else 0.1 + if self.training: + alpha = warmup_scale if torch.rand(()).item() <= 0.9 else 0.1 + else: + alpha = 1.0 # macaron style feed forward module src = src + self.dropout(self.feed_forward_macaron(src)) From 45f872c27da3d52dee592552435a46f7ae2cd374 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 1 Apr 2022 19:33:20 +0800 Subject: [PATCH 149/234] Remove final dropout --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 83de82056..7573addaa 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -80,7 +80,6 @@ class Conformer(EncoderInterface): ) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) - self.final_dropout = nn.Dropout(p=dropout) if output_dim == d_model: self.encoder_output_layer = nn.Identity() else: @@ -119,7 +118,6 @@ class Conformer(EncoderInterface): x = self.encoder(x, pos_emb, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) - x = self.final_dropout(x) x = self.encoder_output_layer(x) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) From e0ba4ef3ec7c0ce3ec1b167820b4a711741da137 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 2 Apr 2022 17:47:12 +0800 Subject: [PATCH 150/234] Make layer dropout rate 0.075, was 0.1. --- .../ASR/pruned_transducer_stateless2/conformer.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 7573addaa..07ff0525a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -40,6 +40,7 @@ class Conformer(EncoderInterface): dim_feedforward (int): feedforward dimention num_encoder_layers (int): number of encoder layers dropout (float): dropout rate + layer_dropout (float): layer-dropout rate. cnn_module_kernel (int): Kernel size of convolution module vgg_frontend (bool): whether to use vgg frontend. """ @@ -53,6 +54,7 @@ class Conformer(EncoderInterface): dim_feedforward: int = 2048, num_encoder_layers: int = 12, dropout: float = 0.1, + layer_dropout: float = 0.075, cnn_module_kernel: int = 31, ) -> None: super(Conformer, self).__init__() @@ -76,6 +78,7 @@ class Conformer(EncoderInterface): nhead, dim_feedforward, dropout, + layer_dropout, cnn_module_kernel, ) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) @@ -149,9 +152,13 @@ class ConformerEncoderLayer(nn.Module): nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, + layer_dropout: float = 0.075, cnn_module_kernel: int = 31, ) -> None: super(ConformerEncoderLayer, self).__init__() + + self.layer_dropout = layer_dropout + self.d_model = d_model self.self_attn = RelPositionMultiheadAttention( @@ -217,10 +224,10 @@ class ConformerEncoderLayer(nn.Module): src_orig = src warmup_scale = min(0.1 + warmup, 1.0) - # alpha = 1.0 means fully use this encoder layer, 0.0 would mean completely - # bypass it. + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. if self.training: - alpha = warmup_scale if torch.rand(()).item() <= 0.9 else 0.1 + alpha = warmup_scale if torch.rand(()).item() <= (1.0 - self.layer_dropout) else 0.1 else: alpha = 1.0 From 8be10d3d6c39dbb51f932c8cebea7cb67055ed92 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 2 Apr 2022 20:03:21 +0800 Subject: [PATCH 151/234] First draft of model rework --- .../pruned_transducer_stateless2/conformer.py | 11 +------ .../pruned_transducer_stateless2/decoder.py | 17 +++++----- .../pruned_transducer_stateless2/joiner.py | 16 ++++++--- .../ASR/pruned_transducer_stateless2/model.py | 33 ++++++++----------- .../ASR/pruned_transducer_stateless2/train.py | 22 ++++++++----- 5 files changed, 49 insertions(+), 50 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index c7ce3bec2..0deb960ad 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -32,8 +32,6 @@ class Conformer(EncoderInterface): """ Args: num_features (int): Number of input features - output_dim (int): Model output dimension. If not equal to the encoder dimension, - we will project to the output. subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension, also the output dimension nhead (int): number of head @@ -47,7 +45,6 @@ class Conformer(EncoderInterface): def __init__( self, num_features: int, - output_dim: int, subsampling_factor: int = 4, d_model: int = 256, nhead: int = 4, @@ -83,11 +80,6 @@ class Conformer(EncoderInterface): ) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) - if output_dim == d_model: - self.encoder_output_layer = nn.Identity() - else: - self.encoder_output_layer = ScaledLinear(d_model, output_dim, - initial_speed=0.5) def forward( self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 @@ -123,7 +115,6 @@ class Conformer(EncoderInterface): x = self.encoder(x, pos_emb, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) - x = self.encoder_output_layer(x) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) return x, lengths @@ -1116,7 +1107,7 @@ class Noam(object): if __name__ == '__main__': feature_dim = 50 - c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4) + c = Conformer(num_features=feature_dim, d_model=128, nhead=4) batch_size = 5 seq_len = 20 # Just make sure the forward pass runs. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index a442feeea..25a36223d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -46,8 +46,8 @@ class Decoder(nn.Module): Args: vocab_size: Number of tokens of the modeling unit including blank. - embedding_dim: - Dimension of the input embedding. + decoder_dim: + Dimension of the input embedding, and of the decoder output. blank_id: The ID of the blank symbol. context_size: @@ -63,23 +63,22 @@ class Decoder(nn.Module): initial_speed = 0.5 self.embedding = ScaledEmbedding( num_embeddings=vocab_size, - embedding_dim=embedding_dim, + embedding_dim=decoder_dim, padding_idx=blank_id, initial_speed=initial_speed ) self.blank_id = blank_id - self.output_linear = ScaledLinear(embedding_dim, embedding_dim) assert context_size >= 1, context_size self.context_size = context_size self.vocab_size = vocab_size if context_size > 1: self.conv = ScaledConv1d( - in_channels=embedding_dim, - out_channels=embedding_dim, + in_channels=decoder_dim, + out_channels=decoder_dim, kernel_size=context_size, padding=0, - groups=embedding_dim, + groups=decoder_dim, bias=False, ) @@ -92,7 +91,7 @@ class Decoder(nn.Module): True to left pad the input. Should be True during training. False to not pad the input. Should be False during inference. Returns: - Return a tensor of shape (N, U, embedding_dim). + Return a tensor of shape (N, U, decoder_dim). """ y = y.to(torch.int64) embedding_out = self.embedding(y) @@ -108,5 +107,5 @@ class Decoder(nn.Module): assert embedding_out.size(-1) == self.context_size embedding_out = self.conv(embedding_out) embedding_out = embedding_out.permute(0, 2, 1) - embedding_out = self.output_linear(F.relu(embedding_out)) + embedding_out = F.relu(embedding_out) return embedding_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index b9c465398..64752b9a0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -20,11 +20,19 @@ import torch.nn.functional as F from scaling import ScaledLinear class Joiner(nn.Module): - def __init__(self, input_dim: int, output_dim: int): + def __init__(self, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int): super().__init__() - self.output_linear = ScaledLinear(input_dim, output_dim, - initial_speed=0.5) + # We don't bother giving the 'initial_speed' arg to the decoder + # submodules, because it does not affect the initial convergence of the + # system (only the simple joiner is involved in that). + self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim) + self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim) + self.output_linear = ScaledLinear(joiner_dim, vocab_size) def forward( self, encoder_out: torch.Tensor, decoder_out: torch.Tensor @@ -41,7 +49,7 @@ class Joiner(nn.Module): assert encoder_out.ndim == decoder_out.ndim == 4 assert encoder_out.shape == decoder_out.shape - logit = encoder_out + decoder_out + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) logit = self.output_linear(torch.tanh(logit)) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 0355c4531..5d4c32ac4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -34,23 +34,25 @@ class Transducer(nn.Module): encoder: EncoderInterface, decoder: nn.Module, joiner: nn.Module, - embedding_dim: int, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, vocab_size: int ): """ Args: encoder: It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, C) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, C) and + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and `logit_lens` of shape (N,). decoder: It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, C). It should contain + is (N, U) and its output shape is (N, U, decoder_dim). It should contain one attribute: `blank_id`. joiner: - It has two inputs with shapes: (N, T, C) and (N, U, C). Its - output shape is (N, T, U, C). Note that its output contains + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). Its + output shape is (N, T, U, vocab_size). Note that its output contains unnormalized probs, i.e., not processed by log-softmax. """ super().__init__() @@ -61,17 +63,10 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size, + self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) - self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size, + self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size, initial_speed=0.5) - with torch.no_grad(): - # Initialize the two projections to be the same; this will be - # convenient for the real joiner, which adds the endcoder - # (acoustic-model/am) and decoder (language-model/lm) embeddings - self.simple_lm_proj.weight[:] = self.simple_am_proj.weight - self.simple_lm_proj.bias[:] = self.simple_am_proj.bias - def forward( self, @@ -133,7 +128,7 @@ class Transducer(nn.Module): # sos_y_padded: [B, S + 1], start with SOS. sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - # decoder_out: [B, S + 1, C] + # decoder_out: [B, S + 1, decoder_dim] decoder_out = self.decoder(sos_y_padded) # Note: y does not start with SOS @@ -167,13 +162,13 @@ class Transducer(nn.Module): s_range=prune_range, ) - # am_pruned : [B, T, prune_range, C] - # lm_pruned : [B, T, prune_range, C] + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] am_pruned, lm_pruned = k2.do_rnnt_pruning( am=encoder_out, lm=decoder_out, ranges=ranges ) - # logits : [B, T, prune_range, C] + # logits : [B, T, prune_range, vocab_size] logits = self.joiner(am_pruned, lm_pruned) pruned_loss = k2.rnnt_loss_pruned( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index c716d457a..a027a5adc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -268,7 +268,7 @@ def get_params() -> AttributeDict: - subsampling_factor: The subsampling factor for the model. - - attention_dim: Hidden dim for multi-head attention model. + - encoder_dim: Hidden dim for multi-head attention model. - num_decoder_layers: Number of decoder layer of transformer decoder. @@ -287,12 +287,14 @@ def get_params() -> AttributeDict: # parameters for conformer "feature_dim": 80, "subsampling_factor": 4, - "attention_dim": 512, + "encoder_dim": 512, "nhead": 8, "dim_feedforward": 2048, "num_encoder_layers": 12, # parameters for decoder - "embedding_dim": 512, + "decoder_dim": 512, + # parameters for joiner + "joiner_dim": 512, # parameters for Noam "warm_step": 60000, # For the 100h subset, use 8k "model_warm_step": 4000, # arg given to model, not for lrate @@ -309,7 +311,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: num_features=params.feature_dim, output_dim=params.embedding_dim, subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, + d_model=params.encoder_dim, nhead=params.nhead, dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, @@ -329,8 +331,10 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - input_dim=params.embedding_dim, - output_dim=params.vocab_size, + encoder_dim=params.encoder_dim + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, ) return joiner @@ -344,7 +348,9 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, - embedding_dim=params.embedding_dim, + encoder_dim=params.encoder_dim + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, ) return model @@ -748,7 +754,7 @@ def run(rank, world_size, args): optimizer = Noam( model.parameters(), - model_size=params.attention_dim, + model_size=params.encoder_dim, factor=params.lr_factor, warm_step=params.warm_step, ) From 34500afc43173444309902fb9aea2d6ad2b15d38 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 2 Apr 2022 20:06:43 +0800 Subject: [PATCH 152/234] Various bug fixes --- .../ASR/pruned_transducer_stateless2/decoder.py | 2 +- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index 25a36223d..3291ad877 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -38,7 +38,7 @@ class Decoder(nn.Module): def __init__( self, vocab_size: int, - embedding_dim: int, + decoder_dim: int, blank_id: int, context_size: int, ): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index a027a5adc..e8fbb6a71 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -309,7 +309,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: # TODO: We can add an option to switch between Conformer and Transformer encoder = Conformer( num_features=params.feature_dim, - output_dim=params.embedding_dim, subsampling_factor=params.subsampling_factor, d_model=params.encoder_dim, nhead=params.nhead, @@ -322,7 +321,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, - embedding_dim=params.embedding_dim, + decoder_dim=params.decoder_dim, blank_id=params.blank_id, context_size=params.context_size, ) @@ -331,7 +330,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - encoder_dim=params.encoder_dim + encoder_dim=params.encoder_dim, decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -348,7 +347,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=params.encoder_dim + encoder_dim=params.encoder_dim, decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, From 807fcada683ab96aaa585427cc49ce4c21522146 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 2 Apr 2022 20:15:11 +0800 Subject: [PATCH 153/234] Change learning speed of simple_lm_proj --- egs/librispeech/ASR/pruned_transducer_stateless2/model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 5d4c32ac4..1dd20c546 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -65,8 +65,7 @@ class Transducer(nn.Module): self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) - self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size, - initial_speed=0.5) + self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( self, From 9f62a0296cd072083399f6862d1df6bee0134555 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 2 Apr 2022 21:16:39 +0800 Subject: [PATCH 154/234] Revert transducer_stateless/ to state in upstream/master --- .../ASR/transducer_stateless/conformer.py | 396 +++++------------- .../ASR/transducer_stateless/decoder.py | 144 +------ .../transducer_stateless/encoder_interface.py | 2 +- .../ASR/transducer_stateless/joiner.py | 4 +- .../ASR/transducer_stateless/model.py | 3 +- .../ASR/transducer_stateless/train.py | 9 +- .../ASR/transducer_stateless/transformer.py | 4 +- 7 files changed, 108 insertions(+), 454 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index ae95d95b4..488c82386 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -18,8 +18,7 @@ import copy import math import warnings -from typing import Optional, Tuple, Sequence -from subsampling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d +from typing import Optional, Tuple import torch from torch import Tensor, nn @@ -57,7 +56,6 @@ class Conformer(Transformer): cnn_module_kernel: int = 31, normalize_before: bool = True, vgg_frontend: bool = False, - aux_layer_period: int = 3 ) -> None: super(Conformer, self).__init__( num_features=num_features, @@ -82,13 +80,17 @@ class Conformer(Transformer): cnn_module_kernel, normalize_before, ) - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, - aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) self.normalize_before = normalize_before - + if self.normalize_before: + self.after_norm = nn.LayerNorm(d_model) + else: + # Note: TorchScript detects that self.after_norm could be used inside forward() + # and throws an error without this change. + self.after_norm = identity def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool = False + self, x: torch.Tensor, x_lens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -115,8 +117,10 @@ class Conformer(Transformer): assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) - x = self.encoder(x, pos_emb, src_key_padding_mask=mask, - warmup_mode=warmup_mode) # (T, N, C) + x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C) + + if self.normalize_before: + x = self.after_norm(x) logits = self.encoder_output_layer(x) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -154,41 +158,42 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention( d_model, nhead, dropout=0.0 ) self.feed_forward = nn.Sequential( - ScaledLinear(d_model, dim_feedforward), - ActivationBalancer(channel_dim=-1), - DoubleSwish(), + nn.Linear(d_model, dim_feedforward), + Swish(), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + nn.Linear(dim_feedforward, d_model), ) self.feed_forward_macaron = nn.Sequential( - ScaledLinear(d_model, dim_feedforward), - ActivationBalancer(channel_dim=-1), - DoubleSwish(), + nn.Linear(d_model, dim_feedforward), + Swish(), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + nn.Linear(dim_feedforward, d_model), ) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module + self.norm_ff = nn.LayerNorm(d_model) # for the FNN module + self.norm_mha = nn.LayerNorm(d_model) # for the MHA module - self.norm_final = BasicNorm(d_model) + self.ff_scale = 0.5 - # try to ensure the output is close to zero-mean (or at least, zero-median). - self.balancer = ActivationBalancer(channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - max_positive=6.0) + self.norm_conv = nn.LayerNorm(d_model) # for the CNN module + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block self.dropout = nn.Dropout(dropout) + self.normalize_before = normalize_before def forward( self, @@ -215,10 +220,19 @@ class ConformerEncoderLayer(nn.Module): """ # macaron style feed forward module - src = src + self.dropout(self.feed_forward_macaron(src)) - + residual = src + if self.normalize_before: + src = self.norm_ff_macaron(src) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) + if not self.normalize_before: + src = self.norm_ff_macaron(src) # multi-headed self-attention module + residual = src + if self.normalize_before: + src = self.norm_mha(src) src_att = self.self_attn( src, src, @@ -227,15 +241,28 @@ class ConformerEncoderLayer(nn.Module): attn_mask=src_mask, key_padding_mask=src_key_padding_mask, )[0] - src = src + self.dropout(src_att) + src = residual + self.dropout(src_att) + if not self.normalize_before: + src = self.norm_mha(src) # convolution module - src = src + self.dropout(self.conv_module(src)) + residual = src + if self.normalize_before: + src = self.norm_conv(src) + src = residual + self.dropout(self.conv_module(src)) + if not self.normalize_before: + src = self.norm_conv(src) # feed forward module - src = src + self.dropout(self.feed_forward(src)) + residual = src + if self.normalize_before: + src = self.norm_ff(src) + src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) + if not self.normalize_before: + src = self.norm_ff(src) - src = self.norm_final(self.balancer(src)) + if self.normalize_before: + src = self.norm_final(src) return src @@ -255,20 +282,12 @@ class ConformerEncoder(nn.Module): >>> out = conformer_encoder(src, pos_emb) """ - def __init__(self, encoder_layer: nn.Module, num_layers: int, - aux_layers: Sequence[int]) -> None: + def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: super().__init__() self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] ) - self.aux_layers = set(aux_layers + [num_layers - 1]) - assert num_layers - 1 not in aux_layers self.num_layers = num_layers - num_channels = encoder_layer.d_model - self.combiner = RandomCombine(num_inputs=len(self.aux_layers), - final_weight=0.5, - pure_prob=0.333, - stddev=2.0) def forward( self, @@ -276,7 +295,6 @@ class ConformerEncoder(nn.Module): pos_emb: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - warmup_mode: bool = False ) -> Tensor: r"""Pass the input through the encoder layers in turn. @@ -296,19 +314,14 @@ class ConformerEncoder(nn.Module): """ output = src - outputs = [] - - for i, mod in enumerate(self.layers): + for mod in self.layers: output = mod( output, pos_emb, src_mask=mask, src_key_padding_mask=src_key_padding_mask, ) - if i in self.aux_layers: - outputs.append(output) - output = self.combiner(outputs, warmup_mode) return output @@ -331,6 +344,7 @@ class RelPositionalEncoding(torch.nn.Module): """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model + self.xscale = math.sqrt(self.d_model) self.dropout = torch.nn.Dropout(p=dropout_rate) self.pe = None self.extend_pe(torch.tensor(0.0).expand(1, max_len)) @@ -382,6 +396,7 @@ class RelPositionalEncoding(torch.nn.Module): """ self.extend_pe(x) + x = x * self.xscale pos_emb = self.pe[ :, self.pe.size(1) // 2 @@ -413,7 +428,6 @@ class RelPositionMultiheadAttention(nn.Module): embed_dim: int, num_heads: int, dropout: float = 0.0, - scale_speed: float = 5.0 ) -> None: super(RelPositionMultiheadAttention, self).__init__() self.embed_dim = embed_dim @@ -424,29 +438,25 @@ class RelPositionMultiheadAttention(nn.Module): self.head_dim * num_heads == self.embed_dim ), "embed_dim must be divisible by num_heads" - self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25) + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) # linear transformation for positional encoding. - self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) + self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) # these two learnable bias are used in matrix c and matrix d # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.scale_speed = scale_speed - self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) - self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) + self._reset_parameters() - def _pos_bias_u(self): - return self.pos_bias_u * (self.pos_bias_u_scale * self.scale_speed).exp() - - def _pos_bias_v(self): - return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp() - def _reset_parameters(self) -> None: - nn.init.normal_(self.pos_bias_u, std=0.05) - nn.init.normal_(self.pos_bias_v, std=0.05) + nn.init.xavier_uniform_(self.in_proj.weight) + nn.init.constant_(self.in_proj.bias, 0.0) + nn.init.constant_(self.out_proj.bias, 0.0) + + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) def forward( self, @@ -506,11 +516,11 @@ class RelPositionMultiheadAttention(nn.Module): pos_emb, self.embed_dim, self.num_heads, - self.in_proj.get_weight(), - self.in_proj.get_bias(), + self.in_proj.weight, + self.in_proj.bias, self.dropout, - self.out_proj.get_weight(), - self.out_proj.get_bias(), + self.out_proj.weight, + self.out_proj.bias, training=self.training, key_padding_mask=key_padding_mask, need_weights=need_weights, @@ -614,12 +624,13 @@ class RelPositionMultiheadAttention(nn.Module): assert ( head_dim * num_heads == embed_dim ), "embed_dim must be divisible by num_heads" - scaling = float(head_dim) ** -0.5 if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -651,7 +662,6 @@ class RelPositionMultiheadAttention(nn.Module): _b = _b[_start:_end] q = nn.functional.linear(query, _w, _b) - # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias _start = embed_dim @@ -670,7 +680,6 @@ class RelPositionMultiheadAttention(nn.Module): _b = _b[_start:] v = nn.functional.linear(value, _w, _b) - if attn_mask is not None: assert ( attn_mask.dtype == torch.float32 @@ -720,7 +729,7 @@ class RelPositionMultiheadAttention(nn.Module): ) key_padding_mask = key_padding_mask.to(torch.bool) - q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim) + q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) k = k.contiguous().view(-1, bsz, num_heads, head_dim) v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) @@ -741,11 +750,11 @@ class RelPositionMultiheadAttention(nn.Module): p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) - q_with_bias_u = (q + self._pos_bias_u()).transpose( + q_with_bias_u = (q + self.pos_bias_u).transpose( 1, 2 ) # (batch, head, time1, d_k) - q_with_bias_v = (q + self._pos_bias_v()).transpose( + q_with_bias_v = (q + self.pos_bias_v).transpose( 1, 2 ) # (batch, head, time1, d_k) @@ -765,7 +774,7 @@ class RelPositionMultiheadAttention(nn.Module): attn_output_weights = ( matrix_ac + matrix_bd - ) # (batch, head, time1, time2) + ) * scaling # (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, -1 @@ -840,7 +849,7 @@ class ConvolutionModule(nn.Module): # kernerl_size should be a odd number for 'SAME' padding assert (kernel_size - 1) % 2 == 0 - self.pointwise_conv1 = ScaledConv1d( + self.pointwise_conv1 = nn.Conv1d( channels, 2 * channels, kernel_size=1, @@ -848,25 +857,7 @@ class ConvolutionModule(nn.Module): padding=0, bias=bias, ) - - # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). - # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, - # but sometimes, for some reason, for layer 0 the rms ends up being very large, - # between 50 and 100 for different channels. This will cause very peaky and - # sparse derivatives for the sigmoid gating function, which will tend to make - # the loss function not learn effectively. (for most layers the average absolute values - # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, - # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different - # layers, which likely breaks down as 0.5 for the "linear" half and - # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we - # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, - # it will be in a better position to start learning something, i.e. to latch onto - # the correct range. - self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0, - min_positive=0.05, - max_positive=1.0) - - self.depthwise_conv = ScaledConv1d( + self.depthwise_conv = nn.Conv1d( channels, channels, kernel_size, @@ -875,22 +866,16 @@ class ConvolutionModule(nn.Module): groups=channels, bias=bias, ) - - self.deriv_balancer2 = ActivationBalancer(channel_dim=1, - min_positive=0.05, - max_positive=1.0) - - self.activation = DoubleSwish() - - self.pointwise_conv2 = ScaledConv1d( + self.norm = nn.LayerNorm(channels) + self.pointwise_conv2 = nn.Conv1d( channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, - initial_scale=0.25 ) + self.activation = Swish() def forward(self, x: Tensor) -> Tensor: """Compute convolution module. @@ -907,14 +892,15 @@ class ConvolutionModule(nn.Module): # GLU mechanism x = self.pointwise_conv1(x) # (batch, 2*channels, time) - - x = self.deriv_balancer1(x) x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv x = self.depthwise_conv(x) + # x is (batch, channels, time) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = x.permute(0, 2, 1) - x = self.deriv_balancer2(x) x = self.activation(x) x = self.pointwise_conv2(x) # (batch, channel, time) @@ -922,197 +908,13 @@ class ConvolutionModule(nn.Module): return x.permute(2, 0, 1) -class Identity(torch.nn.Module): +class Swish(torch.nn.Module): + """Construct an Swish object.""" + def forward(self, x: Tensor) -> Tensor: - return x + """Return Swich activation function.""" + return x * torch.sigmoid(x) -class RandomCombine(torch.nn.Module): - """ - This module combines a list of Tensors, all with the same shape, to - produce a single output of that same shape which, in training time, - is a random combination of all the inputs; but which in test time - will be just the last input. - - The idea is that the list of Tensors will be a list of outputs of multiple - conformer layers. This has a similar effect as iterated loss. (See: - DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER - NETWORKS). - """ - def __init__(self, num_inputs: int, - final_weight: float = 0.5, - pure_prob: float = 0.5, - stddev: float = 2.0) -> None: - """ - Args: - num_inputs: The number of tensor inputs, which equals the number of layers' - outputs that are fed into this module. E.g. in an 18-layer neural - net if we output layers 16, 12, 18, num_inputs would be 3. - final_weight: The amount of weight or probability we assign to the - final layer when randomly choosing layers or when choosing - continuous layer weights. - pure_prob: The probability, on each frame, with which we choose - only a single layer to output (rather than an interpolation) - stddev: A standard deviation that we add to log-probs for computing - randomized weights. - - The method of choosing which layers, - or combinations of layers, to use, is conceptually as follows. - With probability `pure_prob`: - With probability `final_weight`: choose final layer, - Else: choose random non-final layer. - Else: - Choose initial log-weights that correspond to assigning - weight `final_weight` to the final layer and equal - weights to other layers; then add Gaussian noise - with variance `stddev` to these log-weights, and normalize - to weights (note: the average weight assigned to the - final layer here will not be `final_weight` if stddev>0). - """ - super(RandomCombine, self).__init__() - assert pure_prob >= 0 and pure_prob <= 1 - assert final_weight > 0 and final_weight < 1 - assert num_inputs >= 1 - - self.num_inputs = num_inputs - self.final_weight = final_weight - self.pure_prob = pure_prob - self.stddev= stddev - - self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item() - - - def forward(self, inputs: Sequence[Tensor], - warmup_mode: bool) -> Tensor: - """ - Forward function. - Args: - inputs: a list of Tensor, e.g. from various layers of a transformer. - All must be the same shape, of (*, num_channels) - Returns: - a Tensor of shape (*, num_channels). In test mode - this is just the final input. - """ - num_inputs = self.num_inputs - assert len(inputs) == num_inputs - if not (self.training and warmup_mode): - return inputs[-1] - - # Shape of weights: (*, num_inputs) - num_channels = inputs[0].shape[-1] - num_frames = inputs[0].numel() // num_channels - - ndim = inputs[0].ndim - # stacked_inputs: (num_frames, num_channels, num_inputs) - stacked_inputs = torch.stack(inputs, dim=ndim).reshape((num_frames, - num_channels, - num_inputs)) - - # weights: (num_frames, num_inputs) - weights = self._get_random_weights(inputs[0].dtype, inputs[0].device, - num_frames) - - weights = weights.reshape(num_frames, num_inputs, 1) - # ans: (num_frames, num_channels, 1) - ans = torch.matmul(stacked_inputs, weights) - # ans: (*, num_channels) - ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels) - - if __name__ == "__main__": - # for testing only... - print("Weights = ", weights.reshape(num_frames, num_inputs)) - return ans - - - def _get_random_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int) -> Tensor: - """ - Return a tensor of random weights, of shape (num_frames, self.num_inputs), - Args: - dtype: the data-type desired for the answer, e.g. float, double - device: the device needed for the answer - num_frames: the number of sets of weights desired - Returns: a tensor of shape (num_frames, self.num_inputs), such that - ans.sum(dim=1) is all ones. - - """ - pure_prob = self.pure_prob - if pure_prob == 0.0: - return self._get_random_mixed_weights(dtype, device, num_frames) - elif pure_prob == 1.0: - return self._get_random_pure_weights(dtype, device, num_frames) - else: - p = self._get_random_pure_weights(dtype, device, num_frames) - m = self._get_random_mixed_weights(dtype, device, num_frames) - return torch.where(torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m) - - def _get_random_pure_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): - """ - Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs), - Args: - dtype: the data-type desired for the answer, e.g. float, double - device: the device needed for the answer - num_frames: the number of sets of weights desired - Returns: a one-hot tensor of shape (num_frames, self.num_inputs), with - exactly one weight equal to 1.0 on each frame. - """ - - final_prob = self.final_weight - - # final contains self.num_inputs - 1 in all elements - final = torch.full((num_frames,), self.num_inputs - 1, device=device) - # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. - nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) - - indexes = torch.where(torch.rand(num_frames, device=device) < final_prob, - final, nonfinal) - ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(dtype=dtype) - return ans - - - def _get_random_mixed_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): - """ - Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs), - Args: - dtype: the data-type desired for the answer, e.g. float, double - device: the device needed for the answer - num_frames: the number of sets of weights desired - Returns: a tensor of shape (num_frames, self.num_inputs), which elements in [0..1] that - sum to one over the second axis, i.e. ans.sum(dim=1) is all ones. - """ - logprobs = torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) * self.stddev - logprobs[:,-1] += self.final_log_weight - return logprobs.softmax(dim=1) - - -def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): - print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}") - num_inputs = 3 - num_channels = 50 - m = RandomCombine(num_inputs=num_inputs, - final_weight=final_weight, - pure_prob=pure_prob, - stddev=stddev) - - x = [ torch.ones(3, 4, num_channels) for _ in range(num_inputs) ] - - y = m(x, True) - assert y.shape == x[0].shape - assert torch.allclose(y, x[0]) # .. since actually all ones. - - -if __name__ == '__main__': - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.0) - _test_random_combine(0.999, 0, 0.0) - _test_random_combine(0.5, 0, 0.3) - _test_random_combine(0.5, 1, 0.3) - _test_random_combine(0.5, 0.5, 0.3) - - feature_dim = 50 - c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4) - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - f = c(torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - warmup_mode=True) +def identity(x): + return x diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index db51fb1cd..b82fed37b 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -17,9 +17,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch import Tensor -from typing import Optional -from subsampling import ScaledConv1d class Decoder(nn.Module): @@ -55,7 +52,7 @@ class Decoder(nn.Module): 1 means bigram; 2 means trigram. n means (n+1)-gram. """ super().__init__() - self.embedding = ScaledEmbedding( + self.embedding = nn.Embedding( num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=blank_id, @@ -65,7 +62,7 @@ class Decoder(nn.Module): assert context_size >= 1, context_size self.context_size = context_size if context_size > 1: - self.conv = ScaledConv1d( + self.conv = nn.Conv1d( in_channels=embedding_dim, out_channels=embedding_dim, kernel_size=context_size, @@ -85,7 +82,6 @@ class Decoder(nn.Module): Returns: Return a tensor of shape (N, U, embedding_dim). """ - y = y.to(torch.int64) embedding_out = self.embedding(y) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) @@ -100,139 +96,3 @@ class Decoder(nn.Module): embedding_out = self.conv(embedding_out) embedding_out = embedding_out.permute(0, 2, 1) return embedding_out - - - -class ScaledEmbedding(nn.Module): - r"""A simple lookup table that stores embeddings of a fixed dictionary and size. - - This module is often used to store word embeddings and retrieve them using indices. - The input to the module is a list of indices, and the output is the corresponding - word embeddings. - - Args: - num_embeddings (int): size of the dictionary of embeddings - embedding_dim (int): the size of each embedding vector - padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` - (initialized to zeros) whenever it encounters the index. - max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm` - is renormalized to have norm :attr:`max_norm`. - norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``. - scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of - the words in the mini-batch. Default ``False``. - sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. - See Notes for more details regarding sparse gradients. - - Attributes: - weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) - initialized from :math:`\mathcal{N}(0, 1)` - - Shape: - - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract - - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` - - .. note:: - Keep in mind that only a limited number of optimizers support - sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), - :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) - - .. note:: - With :attr:`padding_idx` set, the embedding vector at - :attr:`padding_idx` is initialized to all zeros. However, note that this - vector can be modified afterwards, e.g., using a customized - initialization method, and thus changing the vector used to pad the - output. The gradient for this vector from :class:`~torch.nn.Embedding` - is always zero. - - Examples:: - - >>> # an Embedding module containing 10 tensors of size 3 - >>> embedding = nn.Embedding(10, 3) - >>> # a batch of 2 samples of 4 indices each - >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) - >>> embedding(input) - tensor([[[-0.0251, -1.6902, 0.7172], - [-0.6431, 0.0748, 0.6969], - [ 1.4970, 1.3448, -0.9685], - [-0.3677, -2.7265, -0.1685]], - - [[ 1.4970, 1.3448, -0.9685], - [ 0.4362, -0.4004, 0.9400], - [-0.6431, 0.0748, 0.6969], - [ 0.9124, -2.3616, 1.1151]]]) - - - >>> # example with padding_idx - >>> embedding = nn.Embedding(10, 3, padding_idx=0) - >>> input = torch.LongTensor([[0,2,0,5]]) - >>> embedding(input) - tensor([[[ 0.0000, 0.0000, 0.0000], - [ 0.1535, -2.0309, 0.9315], - [ 0.0000, 0.0000, 0.0000], - [-0.1655, 0.9897, 0.0635]]]) - """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] - - num_embeddings: int - embedding_dim: int - padding_idx: int - scale_grad_by_freq: bool - weight: Tensor - sparse: bool - - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - scale_speed: float = 5.0) -> None: - super(ScaledEmbedding, self).__init__() - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - if padding_idx is not None: - if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' - elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' - padding_idx = self.num_embeddings + padding_idx - self.padding_idx = padding_idx - self.scale_grad_by_freq = scale_grad_by_freq - - self.scale_speed = scale_speed - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() - self.sparse = sparse - - self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) - self.reset_parameters() - - - - def reset_parameters(self) -> None: - nn.init.normal_(self.weight, std=0.05) - nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) - - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def forward(self, input: Tensor) -> Tensor: - scale = (self.scale * self.scale_speed).exp() - if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale - else: - return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) - - def extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' - if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' - if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' - if self.sparse is not False: - s += ', sparse=True' - return s.format(**self.__dict__) diff --git a/egs/librispeech/ASR/transducer_stateless/encoder_interface.py b/egs/librispeech/ASR/transducer_stateless/encoder_interface.py index 3d218dcd0..257facce4 100644 --- a/egs/librispeech/ASR/transducer_stateless/encoder_interface.py +++ b/egs/librispeech/ASR/transducer_stateless/encoder_interface.py @@ -22,7 +22,7 @@ import torch.nn as nn class EncoderInterface(nn.Module): def forward( - self, x: torch.Tensor, x_lens: torch.Tensor + self, x: torch.Tensor, x_lens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py index 241f405b6..b0ba7fd83 100644 --- a/egs/librispeech/ASR/transducer_stateless/joiner.py +++ b/egs/librispeech/ASR/transducer_stateless/joiner.py @@ -16,7 +16,7 @@ import torch import torch.nn as nn -from subsampling import ScaledLinear + class Joiner(nn.Module): def __init__(self, input_dim: int, output_dim: int): @@ -24,7 +24,7 @@ class Joiner(nn.Module): self.input_dim = input_dim self.output_dim = output_dim - self.output_linear = ScaledLinear(input_dim, output_dim) + self.output_linear = nn.Linear(input_dim, output_dim) def forward( self, diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index fc16f2631..8281e1fb5 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -65,7 +65,6 @@ class Transducer(nn.Module): x_lens: torch.Tensor, y: k2.RaggedTensor, modified_transducer_prob: float = 0.0, - warmup_mode: bool = False ) -> torch.Tensor: """ Args: @@ -88,7 +87,7 @@ class Transducer(nn.Module): assert x.size(0) == x_lens.size(0) == y.dim0 - encoder_out, x_lens = self.encoder(x, x_lens, warmup_mode) + encoder_out, x_lens = self.encoder(x, x_lens) assert torch.all(x_lens > 0) # Now for the decoder, i.e., the prediction network diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index fa0410973..d6827c17c 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -111,8 +111,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - # was 2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean, then reworking initialization.. - default="transducer_stateless/randcombine1_expscale3_rework2d", + default="transducer_stateless/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -223,7 +222,6 @@ def get_params() -> AttributeDict: "log_interval": 50, "reset_interval": 200, "valid_interval": 3000, # For the 100h subset, use 800 - "warmup_minibatches": 3000, # use warmup mode for 3k minibatches. # parameters for conformer "feature_dim": 80, "encoder_out_dim": 512, @@ -381,7 +379,6 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, - is_warmup_mode: bool = False ) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. @@ -418,7 +415,6 @@ def compute_loss( x_lens=feature_lens, y=y, modified_transducer_prob=params.modified_transducer_prob, - warmup_mode=is_warmup_mode ) assert loss.requires_grad == is_training @@ -455,7 +451,6 @@ def compute_validation_loss( sp=sp, batch=batch, is_training=False, - is_warmup_mode=False ) assert loss.requires_grad is False tot_loss = tot_loss + loss_info @@ -517,7 +512,6 @@ def train_one_epoch( sp=sp, batch=batch, is_training=True, - is_warmup_mode=(params.batch_idx_train Date: Mon, 4 Apr 2022 13:34:43 +0800 Subject: [PATCH 155/234] Fix to joiner to allow different dims --- egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 64752b9a0..a1226f712 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -47,7 +47,7 @@ class Joiner(nn.Module): Return a tensor of shape (N, T, s_range, C). """ assert encoder_out.ndim == decoder_out.ndim == 4 - assert encoder_out.shape == decoder_out.shape + assert encoder_out.shape[:-1] == decoder_out.shape[:-1] logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) From 99e9d6c4b8ab035d9c1962fc5b6086586d336090 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 4 Apr 2022 13:37:10 +0800 Subject: [PATCH 156/234] Some cleanups --- .../ASR/conformer_ctc/subsampling.py | 422 +----------------- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 - .../ASR/transducer_stateless/diagnostics.py | 338 -------------- 3 files changed, 5 insertions(+), 757 deletions(-) delete mode 100644 egs/librispeech/ASR/transducer_stateless/diagnostics.py diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 0a39b0f33..542fb0364 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -17,8 +17,6 @@ import torch import torch.nn as nn -from torch import Tensor -from typing import Tuple class Conv2dSubsampling(nn.Module): @@ -44,27 +42,16 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - ScaledConv2d( + nn.Conv2d( in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ScaledConv2d( + nn.ReLU(), + nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), + nn.ReLU(), ) - self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) - # set learn_eps=False because out_norm is preceded by `out`, and `out` - # itself has learned scale, so the extra degree of freedom is not - # needed. - self.out_norm = BasicNorm(odim, learn_eps=False) - # constrain median of output to be close to zero. - self.out_balancer = ActivationBalancer(channel_dim=-1, - min_positive=0.45, - max_positive=0.55) - + self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -83,8 +70,6 @@ class Conv2dSubsampling(nn.Module): b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) - x = self.out_norm(x) - x = self.out_balancer(x) return x @@ -174,400 +159,3 @@ class VggSubsampling(nn.Module): b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) return x - - - - - -class ActivationBalancerFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, - channel_dim: int, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 - ) -> Tensor: - if x.requires_grad: - if channel_dim < 0: - channel_dim += x.ndim - sum_dims = [d for d in range(x.ndim) if d != channel_dim] - xgt0 = x > 0 - proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) - factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) - if min_positive != 0.0 else 0.0) - factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) - if max_positive != 1.0 else 0.0) - factor = factor1 + factor2 - if isinstance(factor, float): - factor = torch.zeros_like(proportion_positive) - - mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) - below_threshold = (mean_abs < min_abs) - above_threshold = (mean_abs > max_abs) - - ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) - ctx.max_factor = max_factor - ctx.sum_dims = sum_dims - return x - - @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: - factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors - dtype = x_grad.dtype - scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * - (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) - - neg_delta_grad = x_grad.abs() * (factor + scale_factor) - return x_grad - neg_delta_grad, None, None, None, None, None, None - - -class BasicNorm(torch.nn.Module): - """ - This is intended to be a simpler, and hopefully cheaper, replacement for - LayerNorm. The observation this is based on, is that Transformer-type - networks, especially with pre-norm, sometimes seem to set one of the - feature dimensions to a large constant value (e.g. 50), which "defeats" - the LayerNorm because the output magnitude is then not strongly dependent - on the other (useful) features. Presumably the weight and bias of the - LayerNorm are required to allow it to do this. - - So the idea is to introduce this large constant value as an explicit - parameter, that takes the role of the "eps" in LayerNorm, so the network - doesn't have to do this trick. We make the "eps" learnable. - - Args: - num_channels: the number of channels, e.g. 512. - channel_dim: the axis/dimension corresponding to the channel, - interprted as an offset from the input's ndim if negative. - shis is NOT the num_channels; it should typically be one of - {-2, -1, 0, 1, 2, 3}. - eps: the initial "epsilon" that we add as ballast in: - scale = ((input_vec**2).mean() + epsilon)**-0.5 - Note: our epsilon is actually large, but we keep the name - to indicate the connection with conventional LayerNorm. - learn_eps: if true, we learn epsilon; if false, we keep it - at the initial value. - eps_speed: a constant that determines how fast "eps" learns; - with Adam and variants, this should probably be >= 1, - e.g. 5.0. For SGD and variants, probably a value less than one, - like 0.1, would be suitable, to prevent instability. - """ - def __init__(self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True, - eps_speed: float = 5.0): - super(BasicNorm, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.eps_speed = eps_speed - if learn_eps: - self.eps = nn.Parameter((torch.tensor(eps).log() / eps_speed).detach()) - else: - self.register_buffer('eps', (torch.tensor(eps).log() / eps_speed).detach()) - - - def forward(self, x: Tensor) -> Tensor: - assert x.shape[self.channel_dim] == self.num_channels - scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + - (self.eps * self.eps_speed).exp()) ** -0.5 - return x * scales - - - - -class ScaledLinear(nn.Linear): - """ - A modified version of nn.Linear where the parameters are scaled before - use, via: - weight = self.weight * (self.weight_scale * self.scale_speed).exp() - bias = self.bias * (self.bias_scale * self.scale_speed).exp() - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - scale_speed: a factor that affects how fast the weight_scale - and bias_scale learn; this value is suitable for Adam-type - optimizers. - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - - Note: it uses the default initialization for the weight and bias, - inherited from nn.Linear. For modules with small fan-in, this - may be larger than optimal. - """ - def __init__(self, *args, - scale_speed: float = 5.0, - initial_scale: float = 1.0, - **kwargs): - super(ScaledLinear, self).__init__(*args, **kwargs) - initial_scale = (torch.tensor(initial_scale).log() / scale_speed) - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - self.scale_speed = scale_speed - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter('bias_scale', None) - - self._reset_parameters() # Overrides the reset_parameters in nn.Linear - - def _reset_parameters(self): - std = 0.05 - a = (3 ** 0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) - - def get_weight(self): - return self.weight * (self.weight_scale * self.scale_speed).exp() - - def get_bias(self): - return (None if self.bias is None else - self.bias * (self.bias_scale * self.scale_speed).exp()) - - def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), - self.get_bias()) - - -class ScaledConv1d(nn.Conv1d): - def __init__(self, *args, scale_speed = 5.0, - initial_scale=1.0, **kwargs): - super(ScaledConv1d, self).__init__(*args, **kwargs) - self.scale_speed = scale_speed - initial_scale = (torch.tensor(initial_scale).log() / scale_speed) - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter('bias_scale', None) - self._reset_parameters() # Overrides the reset_parameters in base class - - def _reset_parameters(self): - std = 0.05 - a = (3 ** 0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) - - - def get_weight(self): - return self.weight * (self.weight_scale * self.scale_speed).exp() - - def get_bias(self): - return (None if self.bias is None else - self.bias * (self.bias_scale * self.scale_speed).exp()) - - def forward(self, input: Tensor) -> Tensor: - F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - self.get_weight(), self.get_bias(), self.stride, - _single(0), self.dilation, self.groups) - return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) - - - -class ScaledConv2d(nn.Conv2d): - def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs): - super(ScaledConv2d, self).__init__(*args, **kwargs) - self.scale_speed = scale_speed - initial_scale = (torch.tensor(initial_scale).log() / scale_speed) - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter('bias_scale', None) - self._reset_parameters() # Overrides the reset_parameters in base class - - def _reset_parameters(self): - std = 0.05 - a = (3 ** 0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) - - - def get_weight(self): - return self.weight * (self.weight_scale * self.scale_speed).exp() - - def get_bias(self): - return (None if self.bias is None else - self.bias * (self.bias_scale * self.scale_speed).exp()) - - def _conv_forward(self, input, weight): - F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - weight, self.get_bias(), self.stride, - _pair(0), self.dilation, self.groups) - return F.conv2d(input, weight, self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) - - def forward(self, input: Tensor) -> Tensor: - return self._conv_forward(input, self.get_weight()) - - - - -class ActivationBalancer(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to encourage, for - each channel, that it is positive at least a proportion `threshold` of the - time. It does this by multiplying negative derivative values by up to - (1+max_factor), and positive derivative values by up to (1-max_factor), - interpolated from 1 at the threshold to those extremal values when none - of the inputs are positive. - - - Args: - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - min_positive: the minimum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_positive: the maximum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_factor: the maximum factor by which we modify the derivatives for - either the sign constraint or the magnitude constraint; - e.g. with max_factor=0.02, the the derivatives would be multiplied by - values in the range [0.98..1.02]. - min_abs: the minimum average-absolute-value per channel, which - we allow, before we start to modify the derivatives to prevent - this. - max_abs: the maximum average-absolute-value per channel, which - we allow, before we start to modify the derivatives to prevent - this. - """ - def __init__(self, channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.01, - min_abs: float = 0.2, - max_abs: float = 100.0): - super(ActivationBalancer, self).__init__() - self.channel_dim = channel_dim - self.min_positive = min_positive - self.max_positive = max_positive - self.max_factor = max_factor - self.min_abs = min_abs - self.max_abs = max_abs - - def forward(self, x: Tensor) -> Tensor: - return ActivationBalancerFunction.apply(x, self.channel_dim, - self.min_positive, self.max_positive, - self.max_factor, self.min_abs, - self.max_abs) - - -def _double_swish(x: Tensor) -> Tensor: - # double-swish, implemented/approximated as offset-swish - return x * torch.sigmoid(x - 1.0) - -class DoubleSwishFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - ctx.save_for_backward(x.detach()) - return _double_swish(x) - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - # TODO: can make this more efficient. - x, = ctx.saved_tensors - x.requires_grad = True - with torch.enable_grad(): - y = _double_swish(x) - y.backward(gradient=y_grad) - return x.grad - -class DoubleSwish(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). - """ - return DoubleSwishFunction.apply(x) - - - -def _test_deriv_balancer_sign(): - channel_dim = 0 - probs = torch.arange(0, 1, 0.01) - N = 1000 - x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, - max_factor=0.2, min_abs=0.0) - - y_grad = torch.sign(torch.randn(probs.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_deriv_balancer_sign: x = ", x) - print("_test_deriv_balancer_sign: y grad = ", y_grad) - print("_test_deriv_balancer_sign: x grad = ", x.grad) - -def _test_deriv_balancer_magnitude(): - channel_dim = 0 - magnitudes = torch.arange(0, 1, 0.01) - N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer(channel_dim=0, - min_positive=0.0, max_positive=1.0, - max_factor=0.2, - min_abs=0.2, max_abs=0.8) - - y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_deriv_balancer_magnitude: x = ", x) - print("_test_deriv_balancer_magnitude: y grad = ", y_grad) - print("_test_deriv_balancer_magnitude: x grad = ", x.grad) - - -def _test_basic_norm(): - num_channels = 128 - m = BasicNorm(num_channels=num_channels, channel_dim=1) - - x = torch.randn(500, num_channels) - - y = m(x) - - assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() - print("x rms = ", x_rms) - print("y rms = ", y_rms) - assert y_rms < x_rms - assert y_rms > 0.5 * x_rms - - - - - -if __name__ == '__main__': - _test_deriv_balancer_sign() - _test_deriv_balancer_magnitude() - _test_basic_norm() diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 477afcecb..8dd1459ca 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -22,8 +22,6 @@ import logging from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional -import torch -from lhotse.utils import fix_random_seed import torch from lhotse import CutSet, Fbank, FbankConfig, load_manifest diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py deleted file mode 100644 index 7fd83d56b..000000000 --- a/egs/librispeech/ASR/transducer_stateless/diagnostics.py +++ /dev/null @@ -1,338 +0,0 @@ -import torch -from torch import Tensor -from torch import nn -import math -import random -from typing import Tuple, List - - -class TensorDiagnosticOptions(object): - """ - Options object for tensor diagnostics: - - Args: - memory_limit: the maximum number of bytes we store per tensor (limits how many copies - of the tensor we cache). - max_eig_dim: the maximum dimension for which we print out eigenvalues - (limited for speed reasons). - """ - def __init__(self, - memory_limit: int = (2 ** 20), - max_eig_dim: int = 512): - - self.memory_limit = memory_limit - self.max_eig_dim = max_eig_dim - - def dim_is_summarized(self, size: int): - return size > 10 and size != 31 - - - -def get_tensor_stats(x: Tensor, dim: int, - stats_type: str) -> Tuple[Tensor, int]: - """ - Returns the specified transformation of the Tensor (either x or x.abs() - or (x > 0), summed over all but the index `dim`. - - Args: - x: Tensor, tensor to be analyzed - dim: dimension with 0 <= dim < x.ndim - stats_type: - "abs" -> take abs() before summing - "positive" -> take (x > 0) before summing - "rms" -> square before summing, we'll take sqrt later - "value -> just sum x itself - Returns (stats, count) - where stats is a Tensor of shape (x.shape[dim],), and the count - is an integer saying how many items were counted in each element - of stats. - """ - count = x.numel() // x.shape[dim] - - if stats_type == "eigs": - x = x.transpose(dim, -1) - x = x.reshape(-1, x.shape[-1]) - # shape of returned tensor: (s, s) where s is size of dimension `dim` of original x. - return torch.matmul(x.transpose(0, 1), x), count - elif stats_type == "abs": - x = x.abs() - elif stats_type == "rms": - x = x ** 2 - elif stats_type == "positive": - x = (x > 0).to(dtype=torch.float) - else: - assert stats_type == "value" - - sum_dims = [ d for d in range(x.ndim) if d != dim ] - if len(sum_dims) > 0: - x = torch.sum(x, dim=sum_dims) - x = x.flatten() - return x, count - -def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], - options: TensorDiagnosticOptions, - sizes_same: bool, - stats_type: str): - """ - This function gets diagnostics for a dimension of a module. - Args: - dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim - options: options object - sizes_same: true if all the tensor sizes are the same on this dimension - stats_type: either "abs" or "positive" or "eigs" or "value", - imdictates the type of stats - we accumulate, abs is mean absolute value, "positive" - is proportion of positive to nonnegative values, "eigs" - is eigenvalues after doing outer product on this dim, sum - over all other dimes. - Returns: - Diagnostic as a string, either percentiles or the actual values, - see the code. Will return the empty string if the diagnostics did - not make sense to print out for this dimension, e.g. dimension - mismatch and stats_type == "eigs" - """ - # stats_and_counts is a list of pair (Tensor, int) - stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ] - stats = [ x[0] for x in stats_and_counts ] - counts = [ x[1] for x in stats_and_counts ] - - if stats_type == "eigs": - try: - stats = torch.stack(stats).sum(dim=0) - except: - return '' - count = sum(counts) - stats = stats / count - stats, _ = torch.symeig(stats) - stats = stats.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance - elif sizes_same: - stats = torch.stack(stats).sum(dim=0) - count = sum(counts) - stats = stats / count - else: - stats = [ x[0] / x[1] for x in stats_and_counts ] - stats = torch.cat(stats, dim=0) - if stats_type == 'rms': - stats = stats.sqrt() - - # if `summarize` we print percentiles of the stats; else, - # we print out individual elements. - summarize = (not sizes_same) or options.dim_is_summarized(stats.numel()) - if summarize: - # print out percentiles. - stats = stats.sort()[0] - num_percentiles = 10 - size = stats.numel() - percentiles = [] - for i in range(num_percentiles + 1): - index = (i * (size - 1)) // num_percentiles - percentiles.append(stats[index].item()) - percentiles = [ '%.2g' % x for x in percentiles ] - percentiles = ' '.join(percentiles) - ans = f'percentiles: [{percentiles}]' - else: - ans = stats.tolist() - ans = [ '%.2g' % x for x in ans ] - ans = '[' + ' '.join(ans) + ']' - if stats_type == "value": - # This norm is useful because it is strictly less than the largest - # sqrt(eigenvalue) of the variance, which we print out, and shows, - # speaking in an approximate way, how much of that largest eigenvalue - # can be attributed to the mean of the distribution. - norm = (stats ** 2).sum().sqrt().item() - mean = stats.mean().item() - rms = (stats ** 2).mean().sqrt().item() - ans += f', norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}' - else: - mean = stats.mean().item() - rms = (stats ** 2).mean().sqrt().item() - ans += f', mean={mean:.2g}, rms={rms:.2g}' - return ans - - - -def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor], - options: TensorDiagnosticOptions): - ndim = tensors[0].ndim - if ndim > 1: - stats_types = ["abs", "positive", "value", "rms"] - if tensors[0].shape[dim] <= options.max_eig_dim: - stats_types.append("eigs") - else: - stats_types = [ "value", "abs" ] - - for stats_type in stats_types: - sizes = [ x.shape[dim] for x in tensors ] - sizes_same = all([ x == sizes[0] for x in sizes ]) - s = get_diagnostics_for_dim(dim, tensors, - options, sizes_same, - stats_type) - if s == '': - continue - - min_size = min(sizes) - max_size = max(sizes) - size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}" - # stats_type will be "abs" or "positive". - print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}") - - -class TensorDiagnostic(object): - """ - This class is not directly used by the user, it is responsible for collecting - diagnostics for a single parameter tensor of a torch.Module. - """ - def __init__(self, - opts: TensorDiagnosticOptions, - name: str): - self.name = name - self.opts = opts - self.saved_tensors = [] - - def accumulate(self, x): - if isinstance(x, Tuple): - x = x[0] - if not isinstance(x, Tensor): - return - if x.device == torch.device('cpu'): - x = x.detach().clone() - else: - x = x.detach().to('cpu', non_blocking=True) - self.saved_tensors.append(x) - l = len(self.saved_tensors) - if l & (l - 1) == 0: # power of 2.. - self._limit_memory() - - def _limit_memory(self): - if len(self.saved_tensors) > 1024: - self.saved_tensors = self.saved_tensors[-1024:] - return - - tot_mem = 0.0 - for i in reversed(range(len(self.saved_tensors))): - tot_mem += self.saved_tensors[i].numel() * self.saved_tensors[i].element_size() - if tot_mem > self.opts.memory_limit: - self.saved_tensors = self.saved_tensors[i:] - return - - def print_diagnostics(self): - if len(self.saved_tensors) == 0: - print("{name}: no stats".format(name=self.name)) - return - if self.saved_tensors[0].ndim == 0: - # ensure there is at least one dim. - self.saved_tensors = [ x.unsqueeze(0) for x in self.saved_tensors ] - - try: - device = torch.device('cuda') - torch.ones(1, 1, device) - except: - device = torch.device('cpu') - - ndim = self.saved_tensors[0].ndim - tensors = [x.to(device) for x in self.saved_tensors] - for dim in range(ndim): - print_diagnostics_for_dim(self.name, dim, - tensors, - self.opts) - - -class ModelDiagnostic(object): - def __init__(self, opts: TensorDiagnosticOptions = TensorDiagnosticOptions()): - self.diagnostics = dict() - self.opts = opts - - def __getitem__(self, name: str): - if name not in self.diagnostics: - self.diagnostics[name] = TensorDiagnostic(self.opts, name) - return self.diagnostics[name] - - def print_diagnostics(self): - for k in sorted(self.diagnostics.keys()): - self.diagnostics[k].print_diagnostics() - - - -def attach_diagnostics(model: nn.Module, - opts: TensorDiagnosticOptions) -> ModelDiagnostic: - ans = ModelDiagnostic(opts) - for name, module in model.named_modules(): - if name == '': - name = "" - forward_diagnostic = TensorDiagnostic(opts, name + ".output") - backward_diagnostic = TensorDiagnostic(opts, name + ".grad") - - - # setting model_diagnostic=ans and n=name below, instead of trying to capture the variables, - # ensures that we use the current values. (matters for name, since - # the variable gets overwritten). these closures don't really capture - # by value, only by "the final value the variable got in the function" :-( - def forward_hook(_module, _input, _output, - _model_diagnostic=ans, _name=name): - if isinstance(_output, Tensor): - _model_diagnostic[f"{_name}.output"].accumulate(_output) - elif isinstance(_output, tuple): - for i, o in enumerate(_output): - _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o) - - def backward_hook(_module, _input, _output, - _model_diagnostic=ans, _name=name): - if isinstance(_output, Tensor): - _model_diagnostic[f"{_name}.grad"].accumulate(_output) - elif isinstance(_output, tuple): - for i, o in enumerate(_output): - _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o) - - module.register_forward_hook(forward_hook) - module.register_backward_hook(backward_hook) - - for name, parameter in model.named_parameters(): - - def param_backward_hook(grad, - _parameter=parameter, - _model_diagnostic=ans, - _name=name): - _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter) - _model_diagnostic[f"{_name}.param_grad"].accumulate(grad) - - parameter.register_hook(param_backward_hook) - return ans - - - -def _test_tensor_diagnostic(): - opts = TensorDiagnosticOptions(2**20, 512) - - diagnostic = TensorDiagnostic(opts, "foo") - - for _ in range(10): - diagnostic.accumulate(torch.randn(50, 100) * 10.0) - - diagnostic.print_diagnostics() - - model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80)) - - diagnostic = attach_diagnostics(model, opts) - for _ in range(10): - T = random.randint(200, 300) - x = torch.randn(T, 100) - y = model(x) - y.sum().backward() - - diagnostic.print_diagnostics() - - - -if __name__ == '__main__': - _test_tensor_diagnostic() - - -def _test_func(): - ans = [] - for i in range(10): - x = list() - x.append(i) - def func(): - return x - ans.append(func) - return ans From a5bbcd7b71a9f519e8f6d7830e8265a2d1fc490c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 4 Apr 2022 14:10:38 +0800 Subject: [PATCH 157/234] Make training more efficient, avoid redoing some projections. --- .../ASR/pruned_transducer_stateless2/joiner.py | 12 ++++++++++-- .../ASR/pruned_transducer_stateless2/model.py | 10 ++++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index a1226f712..752a5f774 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -35,7 +35,8 @@ class Joiner(nn.Module): self.output_linear = ScaledLinear(joiner_dim, vocab_size) def forward( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor + self, encoder_out: torch.Tensor, decoder_out: torch.Tensor, + project_input: bool = True ) -> torch.Tensor: """ Args: @@ -43,13 +44,20 @@ class Joiner(nn.Module): Output from the encoder. Its shape is (N, T, s_range, C). decoder_out: Output from the decoder. Its shape is (N, T, s_range, C). + project_input: + If true, apply input projections encoder_proj and decoder_proj. + If this is false, it is the user's responsibility to do this + manually. Returns: Return a tensor of shape (N, T, s_range, C). """ assert encoder_out.ndim == decoder_out.ndim == 4 assert encoder_out.shape[:-1] == decoder_out.shape[:-1] - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + if project_input: + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + else: + logit = encoder_out + decoder_out logit = self.output_linear(torch.tanh(logit)) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 1dd20c546..a9178c8b3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -164,11 +164,17 @@ class Transducer(nn.Module): # am_pruned : [B, T, prune_range, encoder_dim] # lm_pruned : [B, T, prune_range, decoder_dim] am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=encoder_out, lm=decoder_out, ranges=ranges + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges ) # logits : [B, T, prune_range, vocab_size] - logits = self.joiner(am_pruned, lm_pruned) + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, + project_input=False) pruned_loss = k2.rnnt_loss_pruned( logits=logits, From 4929e4cf32f93860aad273223c86a6dc98d611df Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 4 Apr 2022 17:09:25 +0800 Subject: [PATCH 158/234] Change how warm-step is set --- .../ASR/pruned_transducer_stateless2/train.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index e8fbb6a71..bf7f23fab 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -147,6 +147,13 @@ def get_parser(): help="The lr_factor for Noam optimizer", ) + parser.add_argument( + "--warm-step", + type=float, + default=60000, + help="The number of warmup steps for the (modified) Noam optimizer", + ) + parser.add_argument( "--context-size", type=int, @@ -296,7 +303,6 @@ def get_params() -> AttributeDict: # parameters for joiner "joiner_dim": 512, # parameters for Noam - "warm_step": 60000, # For the 100h subset, use 8k "model_warm_step": 4000, # arg given to model, not for lrate "env_info": get_env_info(), } @@ -709,7 +715,6 @@ def run(rank, world_size, args): params.update(vars(args)) if params.full_libri is False: params.valid_interval = 1600 - params.warm_step = 30000 fix_random_seed(params.seed) if world_size > 1: From 72f4a673b106fefc9c88841a3e4ff3a9d1d6fd88 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 4 Apr 2022 20:21:34 +0800 Subject: [PATCH 159/234] First draft of new approach to learning rates + init --- .../pruned_transducer_stateless2/conformer.py | 87 ------ .../ASR/pruned_transducer_stateless2/optim.py | 254 ++++++++++++++++++ .../pruned_transducer_stateless2/scaling.py | 12 +- .../ASR/pruned_transducer_stateless2/train.py | 50 +++- 4 files changed, 299 insertions(+), 104 deletions(-) create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless2/optim.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 0deb960ad..4797cce08 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -1017,93 +1017,6 @@ class Conv2dSubsampling(nn.Module): return x -class Noam(object): - """ - Implements Noam optimizer. - - Proposed in - "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf - - Modified from - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa - - Args: - params: - iterable of parameters to optimize or dicts defining parameter groups - model_size: - attention dimension of the transformer model - factor: - learning rate factor - warm_step: - warmup steps - """ - - def __init__( - self, - params, - model_size: int = 256, - factor: float = 10.0, - warm_step: int = 25000, - weight_decay=0, - ) -> None: - """Construct an Noam object.""" - self.optimizer = torch.optim.Adam( - params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay - ) - self._step = 0 - self.warmup = warm_step - self.factor = factor - self.model_size = model_size - self._rate = 0 - - @property - def param_groups(self): - """Return param_groups.""" - return self.optimizer.param_groups - - def step(self): - """Update parameters and rate.""" - self._step += 1 - rate = self.rate() - for p in self.optimizer.param_groups: - p["lr"] = rate - self._rate = rate - self.optimizer.step() - - def rate(self, step=None): - """Implement `lrate` above.""" - if step is None: - step = self._step - return ( - self.factor - * self.model_size ** (-0.5) - * self.warmup ** (-0.5 - -0.333) - * min(step ** (-0.333), step * self.warmup ** (-1.333)) - ) - - def zero_grad(self): - """Reset gradient.""" - self.optimizer.zero_grad() - - def state_dict(self): - """Return state_dict.""" - return { - "_step": self._step, - "warmup": self.warmup, - "factor": self.factor, - "model_size": self.model_size, - "_rate": self._rate, - "optimizer": self.optimizer.state_dict(), - } - - def load_state_dict(self, state_dict): - """Load state_dict.""" - for key, value in state_dict.items(): - if key == "optimizer": - self.optimizer.load_state_dict(state_dict["optimizer"]) - else: - setattr(self, key, value) - if __name__ == '__main__': feature_dim = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py new file mode 100644 index 000000000..edbebcceb --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -0,0 +1,254 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import random +from typing import List, Optional, Tuple + +import torch +from torch import Tensor +from torch.optim import Optimizer + + +class Eve(Optimizer): + r""" + Implements Eve algorithm. This is a modified version of AdamW with a special + way of setting the weight-decay / shrinkage-factor, which is designed to make the + rms of the parameters approach a particular specified value (generally 0.1). This is + for use with networks with 'scaled' versions of modules (see scaling.py), which + will be close to invariant to the absolute scale on the parameter matrix. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + Eve is unpublished so far. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8, + target_rms=0.1): + + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0 < target_rms <= 10.0: + raise ValueError("Invalid target_rms value: {}".format(target_rms)) + defaults = dict(lr=lr, betas=betas, eps=eps, + target_rms=target_rms) + super(Eve, self).__init__(params, defaults) + + def __setstate__(self, state): + super(Eve, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + beta1, beta2 = group['betas'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + + step_size = group['lr'] / bias_correction1 + target_rms = group['target_rms'] + delta = exp_avg / denom + + # we'll be doing: p += delta * step_size. + # In the normal case delta_rms (the rms value of the elements of + # delta) will be very close to 1.0, but we compute it here so + # that if we don't use a particular parameter, its value won't + # shrink to zero. + # delta_var is the expected change in the variance of the parameter + # values, i.e. of E[param_elem^2], due to this step. It will + # be close to 1. + + # Let us define: + # delta_var_from_update = (delta**2).mean() * step_size * step_size + + # Suppose we are going to shrinkage with a small value epsilon (not the + # same as the eps above!), i.e. param *= (1-epsilon). Then + # if E[param_elem^2] == target_rms^2, + # E[(param_elem*(1-epsilon))^2] == target_rms^2 (1- 2epsilon + epsilon^2), + # which we can put as: + # delta_var_from_shrinkage \simeq -2 epsilon target_rms^2. + # Setting delta_var_from_shrinkage = -delta_var_from_update + # because we want them to cancel, + # delta_var_from_update = 2 epsilon target_rms^2, or: + # epsilon = delta_var_from_update / (2 * target_rms^2) + # = (delta**2).mean() * 0.5 * (step_size / target_rms)**2. + # Note: step_size is close to the learning rate. For an example, if + # lr = 1.0e-04 and target_rms == 0.1, then in the normal case where + # (delta**2).mean() == 1, we will have: + # epsilon = 1.0 * 0.5 * (1.0e-04 / 0.1) = 1.0e-06. + # Note that this is close to the "traditional" value used for weight + # decay. + + # this is the weight-decay amount... + weight_decay = (delta ** 2).mean().sqrt() * ((0.5 * (step_size / target_rms)) ** 2) + + p.mul_(1 - weight_decay) + p.add_(delta, alpha=-step_size) + + return loss + + + +class Noam(object): + """ + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * self.warmup ** (-0.5 - -0.333) + * min(step ** (-0.333), step * self.warmup ** (-1.333)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 4c45205ce..33b4ad908 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -158,7 +158,10 @@ class ScaledLinear(nn.Linear): self._reset_parameters(initial_speed) # Overrides the reset_parameters in nn.Linear def _reset_parameters(self, initial_speed: float): - std = 0.01 / initial_speed + # we plan to use Eve as the optimizer, which will eventually make the stddev approach + # 0.1 as that's the target_rms we set, but we initialize with a larger stddev + # to have the same effect as a warm-up period. + std = 0.5 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -196,7 +199,7 @@ class ScaledConv1d(nn.Conv1d): self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class def _reset_parameters(self, initial_speed: float): - std = 0.01 / initial_speed + std = 0.5 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -241,7 +244,7 @@ class ScaledConv2d(nn.Conv2d): self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class def _reset_parameters(self, initial_speed: float): - std = 0.01 / initial_speed + std = 0.5 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -476,9 +479,8 @@ class ScaledEmbedding(nn.Module): self.reset_parameters(initial_speed) - def reset_parameters(self, initial_speed: float = 1.0) -> None: - std = 0.01 / initial_speed + std = 0.5 / initial_speed nn.init.normal_(self.weight, std=std) nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index bf7f23fab..9d074fdd4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -28,7 +28,10 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --exp-dir pruned_transducer_stateless2/exp \ --full-libri 1 \ --max-duration 300 \ - --lr-factor 1.5 + --initial-lr 0.002 \ + --lr-decay-steps 10000 \ + --num-lr-decays 4 + """ @@ -52,6 +55,7 @@ from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer +from optim import Eve from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -141,17 +145,24 @@ def get_parser(): ) parser.add_argument( - "--lr-factor", + "--initial-lr", type=float, - default=5.0, - help="The lr_factor for Noam optimizer", + default=0.002, + help="The initial learning rate", ) parser.add_argument( - "--warm-step", + "--lr-decay-steps", type=float, - default=60000, - help="The number of warmup steps for the (modified) Noam optimizer", + default=5000, + help="The number of steps before we decay (halve) the learning rate", + ) + + parser.add_argument( + "--num-lr-decays", + type=float, + default=4, + help="The total number of times we decay (halve) the learning rate" ) parser.add_argument( @@ -426,6 +437,7 @@ def save_checkpoint( params: AttributeDict, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optimal.lr_scheduler._LRScheduler] = None, sampler: Optional[CutSampler] = None, rank: int = 0, ) -> None: @@ -449,6 +461,7 @@ def save_checkpoint( model=model, params=params, optimizer=optimizer, + scheduler=scheduler, sampler=sampler, rank=rank, ) @@ -574,6 +587,7 @@ def train_one_epoch( params: AttributeDict, model: nn.Module, optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler._LRScheduler, sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, @@ -594,6 +608,8 @@ def train_one_epoch( The model for training. optimizer: The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. train_dl: Dataloader for the training dataset. valid_dl: @@ -636,6 +652,7 @@ def train_one_epoch( loss.backward() optimizer.step() optimizer.zero_grad() + lr_scheduler.step() if params.print_diagnostics and batch_idx == 5: return @@ -651,6 +668,7 @@ def train_one_epoch( model=model, params=params, optimizer=optimizer, + scheduler=scheduler, sampler=train_dl.sampler, rank=rank, ) @@ -756,17 +774,24 @@ def run(rank, world_size, args): model = DDP(model, device_ids=[rank]) model.device = device - optimizer = Noam( + optimizer = Eve( model.parameters(), - model_size=params.encoder_dim, - factor=params.lr_factor, - warm_step=params.warm_step, - ) + lr=params.initial_lr, betas=(0.9, 0.98), + eps=1e-9, target_rms=0.1) + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + [ n * params.lr_decay_steps for n in range(1, params.num_lr_decays+1) ], + gamma=0.5) + if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") optimizer.load_state_dict(checkpoints["optimizer"]) + if checkpoints and "scheduler" in checkpoints: + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( @@ -839,6 +864,7 @@ def run(rank, world_size, args): params=params, model=model, optimizer=optimizer, + scheduler=scheduler, sp=sp, train_dl=train_dl, valid_dl=valid_dl, From d1f2f934605cebe7e438f483e4f7beee6bf0966e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 4 Apr 2022 22:40:18 +0800 Subject: [PATCH 160/234] Some fixes.. --- .../ASR/pruned_transducer_stateless2/optim.py | 97 +------------------ .../ASR/pruned_transducer_stateless2/train.py | 12 ++- 2 files changed, 12 insertions(+), 97 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index edbebcceb..6f19807dc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -27,7 +27,7 @@ class Eve(Optimizer): r""" Implements Eve algorithm. This is a modified version of AdamW with a special way of setting the weight-decay / shrinkage-factor, which is designed to make the - rms of the parameters approach a particular specified value (generally 0.1). This is + rms of the parameters approach a particular specified value (we suggest 0.1). This is for use with networks with 'scaled' versions of modules (see scaling.py), which will be close to invariant to the absolute scale on the parameter matrix. @@ -120,7 +120,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(group['eps']) step_size = group['lr'] / bias_correction1 target_rms = group['target_rms'] @@ -141,7 +141,7 @@ class Eve(Optimizer): # Suppose we are going to shrinkage with a small value epsilon (not the # same as the eps above!), i.e. param *= (1-epsilon). Then # if E[param_elem^2] == target_rms^2, - # E[(param_elem*(1-epsilon))^2] == target_rms^2 (1- 2epsilon + epsilon^2), + # E[(param_elem*(1-epsilon))^2] == target_rms^2 (1 - 2epsilon + epsilon^2), # which we can put as: # delta_var_from_shrinkage \simeq -2 epsilon target_rms^2. # Setting delta_var_from_shrinkage = -delta_var_from_update @@ -157,98 +157,9 @@ class Eve(Optimizer): # decay. # this is the weight-decay amount... - weight_decay = (delta ** 2).mean().sqrt() * ((0.5 * (step_size / target_rms)) ** 2) + weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2) p.mul_(1 - weight_decay) p.add_(delta, alpha=-step_size) return loss - - - -class Noam(object): - """ - Implements Noam optimizer. - - Proposed in - "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf - - Modified from - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa - - Args: - params: - iterable of parameters to optimize or dicts defining parameter groups - model_size: - attention dimension of the transformer model - factor: - learning rate factor - warm_step: - warmup steps - """ - - def __init__( - self, - params, - model_size: int = 256, - factor: float = 10.0, - warm_step: int = 25000, - weight_decay=0, - ) -> None: - """Construct an Noam object.""" - self.optimizer = torch.optim.Adam( - params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay - ) - self._step = 0 - self.warmup = warm_step - self.factor = factor - self.model_size = model_size - self._rate = 0 - - @property - def param_groups(self): - """Return param_groups.""" - return self.optimizer.param_groups - - def step(self): - """Update parameters and rate.""" - self._step += 1 - rate = self.rate() - for p in self.optimizer.param_groups: - p["lr"] = rate - self._rate = rate - self.optimizer.step() - - def rate(self, step=None): - """Implement `lrate` above.""" - if step is None: - step = self._step - return ( - self.factor - * self.model_size ** (-0.5) - * self.warmup ** (-0.5 - -0.333) - * min(step ** (-0.333), step * self.warmup ** (-1.333)) - ) - - def zero_grad(self): - """Reset gradient.""" - self.optimizer.zero_grad() - - def state_dict(self): - """Return state_dict.""" - return { - "_step": self._step, - "warmup": self.warmup, - "factor": self.factor, - "model_size": self.model_size, - "_rate": self._rate, - "optimizer": self.optimizer.state_dict(), - } - - def load_state_dict(self, state_dict): - """Load state_dict.""" - for key, value in state_dict.items(): - if key == "optimizer": - self.optimizer.load_state_dict(state_dict["optimizer"]) - else: - setattr(self, key, value) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 9d074fdd4..9f73c8fbc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -48,7 +48,7 @@ import torch import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from conformer import Conformer, Noam +from conformer import Conformer from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -437,7 +437,7 @@ def save_checkpoint( params: AttributeDict, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optimal.lr_scheduler._LRScheduler] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, sampler: Optional[CutSampler] = None, rank: int = 0, ) -> None: @@ -652,7 +652,7 @@ def train_one_epoch( loss.backward() optimizer.step() optimizer.zero_grad() - lr_scheduler.step() + scheduler.step() if params.print_diagnostics and batch_idx == 5: return @@ -848,7 +848,7 @@ def run(rank, world_size, args): fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) - cur_lr = optimizer._rate + cur_lr = scheduler.get_last_lr()[0] if tb_writer is not None: tb_writer.add_scalar( "train/learning_rate", cur_lr, params.batch_idx_train @@ -908,12 +908,16 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: + # warmup = 0.0 is so that the derivs for the pruned loss stay zero + # (i.e. are not remembered by the decaying-average in adam), because + # we want to avoid these params being subject to shrinkage in adam. loss, _ = compute_loss( params=params, model=model, sp=sp, batch=batch, is_training=True, + warmup = 0.0 ) loss.backward() optimizer.step() From 179d0605ea235fa92fee7289a3db1374f3ec2bcf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 4 Apr 2022 23:34:39 +0800 Subject: [PATCH 161/234] Change initialization to 0.25 --- egs/librispeech/ASR/pruned_transducer_stateless2/optim.py | 3 ++- .../ASR/pruned_transducer_stateless2/scaling.py | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 6f19807dc..17450def8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -140,7 +140,8 @@ class Eve(Optimizer): # Suppose we are going to shrinkage with a small value epsilon (not the # same as the eps above!), i.e. param *= (1-epsilon). Then - # if E[param_elem^2] == target_rms^2, + # if E[param_elem^2] == target_rms^2 (because we desire equilibrium when + # the RMS of the parameters equals target_rms), it follows that # E[(param_elem*(1-epsilon))^2] == target_rms^2 (1 - 2epsilon + epsilon^2), # which we can put as: # delta_var_from_shrinkage \simeq -2 epsilon target_rms^2. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 33b4ad908..4b91bb04c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -161,7 +161,7 @@ class ScaledLinear(nn.Linear): # we plan to use Eve as the optimizer, which will eventually make the stddev approach # 0.1 as that's the target_rms we set, but we initialize with a larger stddev # to have the same effect as a warm-up period. - std = 0.5 / initial_speed + std = 0.25 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -199,7 +199,7 @@ class ScaledConv1d(nn.Conv1d): self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class def _reset_parameters(self, initial_speed: float): - std = 0.5 / initial_speed + std = 0.25 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -244,7 +244,7 @@ class ScaledConv2d(nn.Conv2d): self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class def _reset_parameters(self, initial_speed: float): - std = 0.5 / initial_speed + std = 0.25 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -480,7 +480,7 @@ class ScaledEmbedding(nn.Module): def reset_parameters(self, initial_speed: float = 1.0) -> None: - std = 0.5 / initial_speed + std = 0.25 / initial_speed nn.init.normal_(self.weight, std=std) nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) From 234366e51c450cb95e18acdcf9d6544d74155885 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 00:18:36 +0800 Subject: [PATCH 162/234] Fix type of parameter --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 9f73c8fbc..83558a72b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -160,7 +160,7 @@ def get_parser(): parser.add_argument( "--num-lr-decays", - type=float, + type=int, default=4, help="The total number of times we decay (halve) the learning rate" ) From 2b0727a355d73205c1a91b770902c0da04aec958 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 00:31:28 +0800 Subject: [PATCH 163/234] Fix weight decay formula by adding 1/1-beta --- egs/librispeech/ASR/pruned_transducer_stateless2/optim.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 17450def8..607a4e350 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -156,9 +156,12 @@ class Eve(Optimizer): # epsilon = 1.0 * 0.5 * (1.0e-04 / 0.1) = 1.0e-06. # Note that this is close to the "traditional" value used for weight # decay. - + # # this is the weight-decay amount... - weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2) + # + # Regarding the 1/1-beta factor below: this is to compensate for the deltas on successive + # frames being correlated. I have to figure out the justification. + weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2 * (1.0 / (1.0 - beta))) p.mul_(1 - weight_decay) p.add_(delta, alpha=-step_size) From 47d49f29d78742e9d22850c08bdd094d1c4bb6f9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 00:31:55 +0800 Subject: [PATCH 164/234] Fix weight decay formula by adding 1/1-beta --- egs/librispeech/ASR/pruned_transducer_stateless2/optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 607a4e350..eb7776938 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -161,7 +161,7 @@ class Eve(Optimizer): # # Regarding the 1/1-beta factor below: this is to compensate for the deltas on successive # frames being correlated. I have to figure out the justification. - weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2 * (1.0 / (1.0 - beta))) + weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2 * (1.0 / (1.0 - beta1))) p.mul_(1 - weight_decay) p.add_(delta, alpha=-step_size) From 1548cc7462a59da00f3bddad7b51166c5a0a3b09 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 11:19:40 +0800 Subject: [PATCH 165/234] Fix checkpoint-writing --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 83558a72b..c63c849c4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -376,6 +376,7 @@ def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, ) -> Optional[Dict[str, Any]]: """Load checkpoint from file. @@ -395,6 +396,8 @@ def load_checkpoint_if_available( The training model. optimizer: The optimizer that we are using. + scheduler: + The scheduler that we are using. Returns: Return a dict containing previously saved training info. """ @@ -411,6 +414,7 @@ def load_checkpoint_if_available( filename, model=model, optimizer=optimizer, + scheduler=scheduler, ) keys = [ @@ -784,6 +788,7 @@ def run(rank, world_size, args): gamma=0.5) + if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") optimizer.load_state_dict(checkpoints["optimizer"]) @@ -792,7 +797,6 @@ def run(rank, world_size, args): logging.info("Loading scheduler state dict") scheduler.load_state_dict(checkpoints["scheduler"]) - if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( 2 ** 22 @@ -881,6 +885,7 @@ def run(rank, world_size, args): params=params, model=model, optimizer=optimizer, + scheduler=scheduler, sampler=train_dl.sampler, rank=rank, ) From 0f5957394bd346c9a0207b66110b7a1bce10f643 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 12:58:43 +0800 Subject: [PATCH 166/234] Fix to reading scheudler from optim --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index c63c849c4..348e2dd47 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -793,7 +793,7 @@ def run(rank, world_size, args): logging.info("Loading optimizer state dict") optimizer.load_state_dict(checkpoints["optimizer"]) - if checkpoints and "scheduler" in checkpoints: + if checkpoints and "scheduler" in checkpoints and checkpoints["scheduler"] is not None: logging.info("Loading scheduler state dict") scheduler.load_state_dict(checkpoints["scheduler"]) From c3169222aee9db0780379a70f9dea9daf5254d78 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 13:23:02 +0800 Subject: [PATCH 167/234] Simplified optimizer, rework somet things.. --- .../ASR/pruned_transducer_stateless2/optim.py | 74 ++++++++----------- .../ASR/pruned_transducer_stateless2/train.py | 22 +++--- 2 files changed, 39 insertions(+), 57 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index eb7776938..b17ebba7c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -27,7 +27,7 @@ class Eve(Optimizer): r""" Implements Eve algorithm. This is a modified version of AdamW with a special way of setting the weight-decay / shrinkage-factor, which is designed to make the - rms of the parameters approach a particular specified value (we suggest 0.1). This is + rms of the parameters approach a particular target_rms (default: 0.1). This is for use with networks with 'scaled' versions of modules (see scaling.py), which will be close to invariant to the absolute scale on the parameter matrix. @@ -43,10 +43,13 @@ class Eve(Optimizer): running averages of gradient and its square (default: (0.9, 0.999)) eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper `On the Convergence of Adam and Beyond`_ - (default: False) + weight_decay (float, optional): weight decay coefficient (default: 3e-4; + this value means that the weight would decay significantly after + about 3k minibatches. Is not multiplied by learning rate, but + is conditional on RMS-value of parameter being > target_rms. + target_rms (float, optional): target root-mean-square value of + parameters, if they fall below this we will stop applying weight decay. + .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 @@ -57,7 +60,7 @@ class Eve(Optimizer): """ def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8, - target_rms=0.1): + weight_decay=3e-4, target_rms=0.1): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) @@ -67,9 +70,12 @@ class Eve(Optimizer): raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0 <= weight_decay <= 0.1: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, target_rms=target_rms) super(Eve, self).__init__(params, defaults) @@ -94,6 +100,9 @@ class Eve(Optimizer): if p.grad is None: continue + + + # Perform optimization step grad = p.grad if grad.is_sparse: @@ -124,46 +133,21 @@ class Eve(Optimizer): step_size = group['lr'] / bias_correction1 target_rms = group['target_rms'] + weight_decay = group['weight_decay'] delta = exp_avg / denom - # we'll be doing: p += delta * step_size. - # In the normal case delta_rms (the rms value of the elements of - # delta) will be very close to 1.0, but we compute it here so - # that if we don't use a particular parameter, its value won't - # shrink to zero. - # delta_var is the expected change in the variance of the parameter - # values, i.e. of E[param_elem^2], due to this step. It will - # be close to 1. - - # Let us define: - # delta_var_from_update = (delta**2).mean() * step_size * step_size - - # Suppose we are going to shrinkage with a small value epsilon (not the - # same as the eps above!), i.e. param *= (1-epsilon). Then - # if E[param_elem^2] == target_rms^2 (because we desire equilibrium when - # the RMS of the parameters equals target_rms), it follows that - # E[(param_elem*(1-epsilon))^2] == target_rms^2 (1 - 2epsilon + epsilon^2), - # which we can put as: - # delta_var_from_shrinkage \simeq -2 epsilon target_rms^2. - # Setting delta_var_from_shrinkage = -delta_var_from_update - # because we want them to cancel, - # delta_var_from_update = 2 epsilon target_rms^2, or: - # epsilon = delta_var_from_update / (2 * target_rms^2) - # = (delta**2).mean() * 0.5 * (step_size / target_rms)**2. - # Note: step_size is close to the learning rate. For an example, if - # lr = 1.0e-04 and target_rms == 0.1, then in the normal case where - # (delta**2).mean() == 1, we will have: - # epsilon = 1.0 * 0.5 * (1.0e-04 / 0.1) = 1.0e-06. - # Note that this is close to the "traditional" value used for weight - # decay. - # - # this is the weight-decay amount... - # - # Regarding the 1/1-beta factor below: this is to compensate for the deltas on successive - # frames being correlated. I have to figure out the justification. - weight_decay = (delta ** 2).mean() * (0.5 * (step_size / target_rms) ** 2 * (1.0 / (1.0 - beta1))) - - p.mul_(1 - weight_decay) - p.add_(delta, alpha=-step_size) + if p.numel() > 1: + # avoid applying this weight-decay on "scaling factors" (which are scalar). + is_below_target_rms = (p.norm() < (target_rms * (p.numel() ** 0.5))) + p.mul_(1 - (weight_decay * is_below_target_rms)) + p.addcdiv_(exp_avg, denom, value=-step_size) return loss + +# Note on avg-change per epoch.. +# suppose epoch is 4k iters. +# if avg-change as rms(diff) / rms(params) equals 0.2, and rms(params) = 0.1, +# then rm(diff) 0.1 * 0.2, var(diff) = (0.1 * 0.2)**2, = 0.0004. So var(diff per minibatch) +# = 0.0004 / 4000 = 1e-07, rms(diff per minibatch) = 3.16e-04. So LR would be 3e-04. +# +# .. 6e-05 is 1/5 of that... diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 348e2dd47..1340e0950 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -152,17 +152,17 @@ def get_parser(): ) parser.add_argument( - "--lr-decay-steps", + "--lr-num-steps", type=float, - default=5000, - help="The number of steps before we decay (halve) the learning rate", + default=3000, + help="Number of steps before we start to significantly decay the learning rate", ) parser.add_argument( - "--num-lr-decays", - type=int, - default=4, - help="The total number of times we decay (halve) the learning rate" + "--lr-power", + type=float, + default=0.5, + help="Power in LR-setting rule", ) parser.add_argument( @@ -781,12 +781,10 @@ def run(rank, world_size, args): optimizer = Eve( model.parameters(), lr=params.initial_lr, betas=(0.9, 0.98), - eps=1e-9, target_rms=0.1) - scheduler = torch.optim.lr_scheduler.MultiStepLR( + eps=1e-9, weight_decay=3e-04, target_rms=0.1) + scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, - [ n * params.lr_decay_steps for n in range(1, params.num_lr_decays+1) ], - gamma=0.5) - + lambda step: (params.lr_num_steps/(step + params.lr_num_steps) ** params.lr_power)) if checkpoints and "optimizer" in checkpoints: From ed8eba91e14f35107fcfe52137015e0806ae6532 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 13:24:09 +0800 Subject: [PATCH 168/234] Reduce model_warm_step from 4k to 3k --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 1340e0950..45b3ca168 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -314,7 +314,7 @@ def get_params() -> AttributeDict: # parameters for joiner "joiner_dim": 512, # parameters for Noam - "model_warm_step": 4000, # arg given to model, not for lrate + "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), } ) From d1a669162caf39fe15e318bfd3f51636cc8826bd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 13:31:52 +0800 Subject: [PATCH 169/234] Fix bug in lambda --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 45b3ca168..3b8f0499f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -784,7 +784,7 @@ def run(rank, world_size, args): eps=1e-9, weight_decay=3e-04, target_rms=0.1) scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, - lambda step: (params.lr_num_steps/(step + params.lr_num_steps) ** params.lr_power)) + lambda step: ((params.lr_num_steps/(step + params.lr_num_steps)) ** params.lr_power)) if checkpoints and "optimizer" in checkpoints: From 25724b5ce9f786f644e662de6e2636add523ce89 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 13:49:35 +0800 Subject: [PATCH 170/234] Bug-fix RE sign of target_rms --- egs/librispeech/ASR/pruned_transducer_stateless2/optim.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index b17ebba7c..2b40dda45 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -137,9 +137,10 @@ class Eve(Optimizer): delta = exp_avg / denom if p.numel() > 1: - # avoid applying this weight-decay on "scaling factors" (which are scalar). - is_below_target_rms = (p.norm() < (target_rms * (p.numel() ** 0.5))) - p.mul_(1 - (weight_decay * is_below_target_rms)) + # avoid applying this weight-decay on "scaling factors" + # (which are scalar). + is_above_target_rms = (p.norm() > (target_rms * (p.numel() ** 0.5))) + p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) return loss @@ -149,5 +150,6 @@ class Eve(Optimizer): # if avg-change as rms(diff) / rms(params) equals 0.2, and rms(params) = 0.1, # then rm(diff) 0.1 * 0.2, var(diff) = (0.1 * 0.2)**2, = 0.0004. So var(diff per minibatch) # = 0.0004 / 4000 = 1e-07, rms(diff per minibatch) = 3.16e-04. So LR would be 3e-04. +# Suggested lr_schedule? # # .. 6e-05 is 1/5 of that... From 2545237eb3ff801364151cd8a82ed01896445a17 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 5 Apr 2022 18:00:54 +0800 Subject: [PATCH 171/234] Changing initial_speed from 0.25 to 01 --- .../ASR/pruned_transducer_stateless2/scaling.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 4b91bb04c..98a56ce77 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -158,10 +158,7 @@ class ScaledLinear(nn.Linear): self._reset_parameters(initial_speed) # Overrides the reset_parameters in nn.Linear def _reset_parameters(self, initial_speed: float): - # we plan to use Eve as the optimizer, which will eventually make the stddev approach - # 0.1 as that's the target_rms we set, but we initialize with a larger stddev - # to have the same effect as a warm-up period. - std = 0.25 / initial_speed + std = 0.1 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -199,7 +196,7 @@ class ScaledConv1d(nn.Conv1d): self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class def _reset_parameters(self, initial_speed: float): - std = 0.25 / initial_speed + std = 0.1 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -244,7 +241,7 @@ class ScaledConv2d(nn.Conv2d): self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class def _reset_parameters(self, initial_speed: float): - std = 0.25 / initial_speed + std = 0.1 / initial_speed a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: @@ -480,7 +477,7 @@ class ScaledEmbedding(nn.Module): def reset_parameters(self, initial_speed: float = 1.0) -> None: - std = 0.25 / initial_speed + std = 0.1 / initial_speed nn.init.normal_(self.weight, std=std) nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) From 022b0f3c558c0653a16833573f7ad19de8266c7b Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Sat, 2 Apr 2022 15:01:45 +0800 Subject: [PATCH 172/234] Modify icefall/__init__.py. (#287) * Modify icefall/__init__.py to import common functions defined in icefall/utils.py. * Modify icefall/__init__.py and .flake8. --- .flake8 | 3 ++- icefall/__init__.py | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/.flake8 b/.flake8 index 229cf1d6c..dd9239b2d 100644 --- a/.flake8 +++ b/.flake8 @@ -13,4 +13,5 @@ per-file-ignores = exclude = .git, **/data/**, - icefall/shared/make_kn_lm.py + icefall/shared/make_kn_lm.py, + icefall/__init__.py diff --git a/icefall/__init__.py b/icefall/__init__.py index e69de29bb..983539d6f 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -0,0 +1,24 @@ +from .utils import ( + AttributeDict, + MetricsTracker, + add_eos, + add_sos, + concat, + encode_supervisions, + get_alignments, + get_executor, + get_texts, + l1_norm, + l2_norm, + linf_norm, + load_alignments, + make_pad_mask, + measure_gradient_norms, + measure_weight_norms, + optim_step_and_measure_param_change, + save_alignments, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) From b0bce20e21f5e6bd35d5a778b08f1ba0d59ce696 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sat, 2 Apr 2022 16:26:13 +0800 Subject: [PATCH 173/234] Modify subsampling.py to make T'=T//4 strictly --- .../subsampling.py | 166 ++++++++++++++++++ .../test_subsampling.py | 25 +++ 2 files changed, 191 insertions(+) create mode 100644 egs/librispeech/ASR/emformer_pruned_transducer_stateless/subsampling.py create mode 100644 egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_subsampling.py diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/subsampling.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/subsampling.py new file mode 100644 index 000000000..7d0ad44a6 --- /dev/null +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/subsampling.py @@ -0,0 +1,166 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where T' == T // 4. + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__(self, idim: int, odim: int) -> None: + """ + Args: + idim: + Input dim. The input shape is (N, T, idim). + Caution: It requires: T >= 4, idim >= 7 + odim: + Output dim. The output shape is (N, T // 4, odim) + """ + assert idim >= 7 + super().__init__() + self.conv_1 = nn.Sequential( + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), + nn.ReLU(), + ) + self.conv_2 = nn.Sequential( + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), + nn.ReLU(), + ) + self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, T // 4, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) + # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = nn.functional.pad(x, (0, 0, 0, 1), "constant", 0) + # x is of shape (N, 1, T + 1, idim) + x = self.conv_1(x) + # Now x is of shape (N, odim, T // 2, (idim - 1) // 2) + x = nn.functional.pad(x, (0, 0, 0, 1), "constant", 0) + # x is of shape (N, odim, T // 2 + 1, (idim - 1) // 2) + x = self.conv_2(x) + # Now x is of shape (N, odim, T // 4, ((idim - 1) // 2 - 1) // 2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, T // 4, odim) + return x + + +class VggSubsampling(nn.Module): + """Trying to follow the setup described in the following paper: + https://arxiv.org/pdf/1910.09799.pdf + + This paper is not 100% explicit so I am guessing to some extent, + and trying to compare with other VGG implementations. + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where approximates T' = T//4. + """ + + def __init__(self, idim: int, odim: int) -> None: + """Construct a VggSubsampling object. + + This uses 2 VGG blocks with 2 Conv2d layers each, + subsampling its input by a factor of 4 in the time dimensions. + + Args: + idim: + Input dim. The input shape is (N, T, idim). + Caution: It requires: T >= 4, idim >= 4. + odim: + Output dim. The output shape is (N, T // 4, odim) + """ + super().__init__() + + cur_channels = 1 + layers = [] + block_dims = [32, 64] + + # The decision to use padding=1 for the 1st convolution, then padding=0 + # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by + # a back-compatibility concern so that the number of frames at the + # output would be equal to: + # (((T-1)//2)-1)//2. + # We can consider changing this by using padding=1 on the + # 2nd convolution, so the num-frames at the output would be T//4. + for block_dim in block_dims: + layers.append( + torch.nn.Conv2d( + in_channels=cur_channels, + out_channels=block_dim, + kernel_size=3, + padding=1, + stride=1, + ) + ) + layers.append(torch.nn.ReLU()) + layers.append( + torch.nn.Conv2d( + in_channels=block_dim, + out_channels=block_dim, + kernel_size=3, + padding=1, + stride=1, + ) + ) + layers.append( + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=False + ) + ) + cur_channels = block_dim + + self.layers = nn.Sequential(*layers) + + self.out = nn.Linear(block_dims[-1] * (idim // 4), odim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, T // 4, odim) + """ + x = x.unsqueeze(1) + x = self.layers(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + return x diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_subsampling.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_subsampling.py new file mode 100644 index 000000000..338688564 --- /dev/null +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_subsampling.py @@ -0,0 +1,25 @@ +import torch +from subsampling import Conv2dSubsampling, VggSubsampling + + +def test_conv2d_subsampling(): + B, idim, odim = 1, 80, 512 + model = Conv2dSubsampling(idim, odim) + for t in range(4, 50): + x = torch.randn(B, t, idim) + outputs = model(x) + assert outputs.shape == (B, t // 4, odim) + + +def test_vgg_subsampling(): + B, idim, odim = 1, 80, 512 + model = VggSubsampling(idim, odim) + for t in range(4, 50): + x = torch.randn(B, t, idim) + outputs = model(x) + assert outputs.shape == (B, t // 4, odim) + + +if __name__ == "__main__": + test_conv2d_subsampling() + test_vgg_subsampling() From fe43c1349ee02687c07e6284acd56ccc373fec1e Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sat, 2 Apr 2022 21:14:24 +0800 Subject: [PATCH 174/234] First upload emformer_pruned_transducer_stateless/emformer.py, modified from torchaudio. --- .../emformer.py | 1422 +++++++++++++++++ 1 file changed, 1422 insertions(+) create mode 100644 egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py new file mode 100644 index 000000000..88b1a06fb --- /dev/null +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py @@ -0,0 +1,1422 @@ +import math +from typing import List, Optional, Tuple + +import torch +from torch import nn + +from icefall.utils import make_pad_mask +from encoder_interface import EncoderInterface +from subsampling import Conv2dSubsampling, VggSubsampling + + +def _gen_padding_mask( + utterance: torch.Tensor, + right_context: torch.Tensor, + lengths: torch.Tensor, + mems: torch.Tensor, + left_context_key: Optional[torch.Tensor] = None, +) -> Optional[torch.Tensor]: + """Generate padding mask according to the length of the tensors + contained in the key. + + Args: + utterance: (U, B, D) + right_context: (R, B, D) + lengths: (B,) + mems: (M, B, D) + left_context_key: (L, B, D) + B is the batch size, D is the feature dimension, + U is the length of the utterance, + R is the length of the right context block, + M is the length of the memory block, + L is the length of the left context block + + Returns: + padding_mask: + Padding mask for the concatenated key tensor + [mems, right_context, left_context, utterance], + sharing for all queries, with shape of (M + R + L + U, B) + """ + assert utterance.size(0) == torch.max(lengths) + B = utterance.size(1) + M = mems.size(0) + R = right_context.size(0) + L = left_context_key.size(0) if left_context_key is not None else 0 + if B == 1: + # TODO: for infer mode? + padding_mask = None + else: + lengths_concat = M + R + L + lengths + padding_mask = make_pad_mask(lengths_concat) + return padding_mask + + +def _get_activation_module(activation: str) -> nn.Module: + if activation == "relu": + return nn.ReLU() + elif activation == "gelu": + return nn.GELU() + elif activation == "silu": + return nn.SiLU() + else: + raise ValueError(f"Unsupported activation {activation}") + + +def _get_weight_init_gains( + weight_init_scale_strategy: Optional[str], + num_layers: int +) -> List[Optional[float]]: + if weight_init_scale_strategy is None: + return [None for _ in range(num_layers)] + elif weight_init_scale_strategy == "depthwise": + return [1.0 / math.sqrt(layer_idx + 1) + for layer_idx in range(num_layers)] + elif weight_init_scale_strategy == "constant": + return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)] + else: + raise ValueError(f"Unsupported weight_init_scale_strategy value" + f"{weight_init_scale_strategy}") + + +def _gen_attention_mask_block( + col_widths: List[int], + col_mask: List[bool], + num_rows: int, + device: torch.device +) -> torch.Tensor: + assert len(col_widths) == len(col_mask), ( + "Length of col_widths must match that of col_mask") + + mask_block = [ + torch.ones(num_rows, col_width, device=device) + if is_ones_col + else torch.zeros(num_rows, col_width, device=device) + for col_width, is_ones_col in zip(col_widths, col_mask) + ] + return torch.cat(mask_block, dim=1) + + +def length_down_sampling(length): + # Caution: We assume the subsampling factor is 4! + return ((length - 1) // 2 - 1) // 2 + + +class EmformerAttention(nn.Module): + r"""Emformer layer attention module. + + Args: + embed_dim (int): + Embedding dimension. + nhead (int): + Number of attention heads in each Emformer layer. + dropout (float, optional): + Dropout probability. (Default: 0.0) + weight_init_gain (float or None, optional): + Scale factor to apply when initializing attention + module parameters. (Default: ``None``) + tanh_on_mem (bool, optional): + If ``True``, applies tanh to memory elements. (Default: ``False``) + negative_inf (float, optional): + Value to use for negative infinity in attention weights. (Default: -1e8) + """ + + def __init__( + self, + embed_dim: int, + nhead: int, + dropout: float = 0.0, + weight_init_gain: Optional[float] = None, + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + ): + super().__init__() + + if embed_dim % nhead != 0: + raise ValueError( + f"embed_dim ({embed_dim}) is not a multiple of" + f"nhead ({nhead})." + ) + + self.embed_dim = embed_dim + self.nhead = nhead + self.dropout = dropout + self.tanh_on_mem = tanh_on_mem + self.negative_inf = negative_inf + + self.scaling = (self.embed_dim // self.nhead) ** -0.5 + + self.emb_to_key_value = nn.Linear( + embed_dim, 2 * embed_dim, bias=True + ) + self.emb_to_query = nn.Linear(embed_dim, embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + if weight_init_gain: + nn.init.xavier_uniform_( + self.emb_to_key_value.weight, gain=weight_init_gain + ) + nn.init.xavier_uniform_( + self.emb_to_query.weight, gain=weight_init_gain + ) + + def _gen_attention_probs( + self, + attention_weights: torch.Tensor, + attention_mask: torch.Tensor, + padding_mask: Optional[torch.Tensor], + ) -> torch.Tensor: + """ Given the entire attention weights, mask out unecessary connections + and optionally with padding positions, to obtain underlying chunk-wise + attention probabilities. + + B: batch size; + Q: length of query; + KV: length of key and value. + + Args: + attention_weights (torch.Tensor): + Attention weights computed on the entire concatenated tensor + with shape (B * nhead, Q, KV). + attention_mask (torch.Tensor): + Mask tensor where chunk-wise connections are filled with `False`, + and other unnecessary connections are filled with `True`, + with shape (Q, KV). + padding_mask (torch.Tensor, optional): + Mask tensor where the padding positions are fill with `True`, + and other positions are filled with `False`, with shapa `(B, KV)`. + + Returns: + A tensor of shape (B * nhead, Q, KV). + """ + attention_weights_float = attention_weights.float() + attention_weights_float = attention_weights_float.masked_fill( + attention_mask.unsqueeze(0), self.negative_inf + ) + if padding_mask is not None: + Q = attention_weights.size(1) + B = attention_weights.size(0) // self.nhead + attention_weights_float = attention_weights_float.view( + B, self.nhead, Q, -1 + ) + attention_weights_float = attention_weights_float.masked_fill( + padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + self.negative_inf + ) + attention_weights_float = attention_weights_float.view( + B * self.nhead, Q, -1 + ) + + attention_probs = nn.functional.softmax( + attention_weights_float, dim=-1 + ).type_as(attention_weights) + attention_probs = nn.functional.dropout( + attention_probs, + p=float(self.dropout), + training=self.training + ) + return attention_probs + + def _forward_impl( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + summary: torch.Tensor, + memory: torch.Tensor, + attention_mask: torch.Tensor, + left_context_key: Optional[torch.Tensor] = None, + left_context_val: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ Underlying chunk-wise attention implementation. + + L: length of left_context; + S: length of summary; + M: length of memory; + Q: length of attention query; + KV: length of attention key and value. + + 1) Concat right_context, utterance, summary, + and compute query tensor with length Q = R + U + S. + 2) Concat memory, right_context, utterance, + and compute key, value tensors with length KV = M + R + U; + optionally with left_context_key and left_context_val (inference mode) + then KV = M + R + L + U. + 3) Compute entire attention scores with query, key, and value, + then apply attention_mask to get underlying chunk-wise attention scores. + + Args: + utterance (torch.Tensor): + Utterance frames, with shape (U, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing + number of valid frames for i-th batch element in utterance. + right_context (torch.Tensor): + Right context frames, with shape (R, B, D). + summary (torch.Tensor): + Summary elements, with shape (S, B, D). + memory (torch.Tensor): + Memory elements, with shape (M, B, D). + attention_mask (torch.Tensor): + Attention mask for underlying attention, with shape (Q, KV). + left_context_key (torch,Tensor, optional): + Cached attention key of left context from preceding computation, + with shape (L, B, D). + left_context_val (torch.Tensor, optional): + Cached attention value of left context from preceding computation, + with shape (L, B, D). + + Returns: + A tuple containing 4 tensors: + - output of right context and utterance, with shape (R + U, B, D). + - memory output, with shape (S, B, D). + - attention key, with shape (KV, B, D). + - attention value, with shape (KV, B, D). + """ + B = utterance.size(1) + + # Compute query with [right context, utterance, summary]. + query = self.emb_to_query( + torch.cat([right_context, utterance, summary]) + ) + # Compute key and value with [mems, right context, utterance]. + key, value = self.emb_to_key_value( + torch.cat([memory, right_context, utterance]) + ).chunk(chunks=2, dim=2) + + if left_context_key is not None and left_context_val is not None: + # Now compute key and value with + # [mems, right context, left context, uttrance] + M = memory.size(0) + R = right_context.size(0) + key = torch.cat([key[:M + R], left_context_key, key[M + R:]]) + value = torch.cat([value[:M + R], left_context_val, value[M + R:]]) + + # Compute attention weights from query, key, and value. + reshaped_query, reshaped_key, reshaped_value = [ + tensor.contiguous().view( + -1, B * self.nhead, self.embed_dim // self.nhead + ).transpose(0, 1) for tensor in [query, key, value] + ] + attention_weights = torch.bmm( + reshaped_query * self.scaling, reshaped_key.transpose(1, 2) + ) + + # Compute padding mask + if B == 1: + padding_mask = None + else: + KV = key.size(0) + U = utterance.size(0) + padding_mask = make_pad_mask(KV - U + lengths) + + # Compute attention probabilities. + attention_probs = self._gen_attention_probs( + attention_weights, attention_mask, padding_mask + ) + + # Compute attention. + attention = torch.bmm(attention_probs, reshaped_value) + Q = query.size(0) + assert attention.shape == ( + B * self.nhead, Q, self.embed_dim // self.nhead, + ) + attention = attention.transpose(0, 1).contiguous().view( + Q, B, self.embed_dim + ) + + # Apply output projection. + outputs = self.out_proj(attention) + + S = summary.size(0) + output_right_context_utterance = outputs[:-S] + output_memory = outputs[-S:] + if self.tanh_on_mem: + output_memory = torch.tanh(output_memory) + else: + output_memory = torch.clamp(output_memory, min=-10, max=10) + + return output_right_context_utterance, output_memory, key, value + + def forward( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + summary: torch.Tensor, + memory: torch.Tensor, + attention_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # TODO: Modify docs. + """Forward pass for training. + + B: batch size; + D: embedding dimension; + R: length of right_context; + U: length of utterance; + S: length of summary; + M: length of memory. + + Args: + utterance (torch.Tensor): + Utterance frames, with shape (U, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing + number of valid frames for i-th batch element in utterance. + right_context (torch.Tensor): + Right context frames, with shape (R, B, D). + summary (torch.Tensor): + Summary elements, with shape (S, B, D). + memory (torch.Tensor): + Memory elements, with shape (M, B, D). + attention_mask (torch.Tensor): + Attention mask for underlying chunk-wise attention, + with shape (Q, KV). + + Returns: + A tuple containing 2 tensors: + - output of right context and utterance, with shape (R + U, B, D). + - memory output, with shape (M, B, D), where M = S - 1. + """ + output_right_context_utterance, output_memory, _, _ = \ + self._forward_impl( + utterance, + lengths, + right_context, + summary, + memory, + attention_mask + ) + return output_right_context_utterance, output_memory[:-1] + + @torch.jit.export + def infer( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + summary: torch.Tensor, + memory: torch.Tensor, + left_context_key: torch.Tensor, + left_context_val: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward pass for inference. + + B: batch size; + D: embedding dimension; + R: length of right_context; + U: length of utterance; + L: length of left_context; + S: length of summary; + M: length of memory; + + Args: + utterance (torch.Tensor): + Utterance frames, with shape (U, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing + number of valid frames for i-th batch element in utterance. + right_context (torch.Tensor): + Right context frames, with shape (R, B, D). + summary (torch.Tensor): + Summary elements, with shape (S, B, D). + memory (torch.Tensor): + Memory elements, with shape (M, B, D). + left_context_key (torch,Tensor): + Cached attention key of left context from preceding computation, + with shape (L, B, D). + left_context_val (torch.Tensor): + Cached attention value of left context from preceding computation, + with shape (L, B, D). + + Returns: + A tuple containing 4 tensors: + - output of right context and utterance, with shape (R + U, B, D). + - memory output, with shape (S, B, D). + - attention key of left context and utterance, which would be cached + for next computation, with shape (L + U, B, D). + - attention value of left context and utterance, which would be + cached for next computation, with shape (L + U, B, D). + """ + # query: [right context, utterance, summary] + Q = right_context.size(0) + utterance.size(0) + summary.size(0) + # key, value: [memory, right context, left context, uttrance] + KV = memory.size(0) + right_context.size(0) + \ + left_context_key.size(0) + utterance.size(0) + attention_mask = torch.zeros( + Q, KV + ).to(dtype=torch.bool, device=utterance.device) + # Disallow attention bettween the summary vector with the memory bank + attention_mask[-1, :memory.size(0)] = True + output_right_context_utterance, output_memory, key, value = \ + self._forward_impl( + utterance, + lengths, + right_context, + summary, + memory, + attention_mask, + left_context_key=left_context_key, + left_context_val=left_context_val, + ) + return ( + output_right_context_utterance, + output_memory, + key[memory.size(0) + right_context.size(0):], + value[memory.size(0) + right_context.size(0):], + ) + + +class EmformerLayer(nn.Module): + """Emformer layer that constitutes Emformer. + + Args: + d_model (int): + Input dimension. + nhead (int): + Number of attention heads. + dim_feedforward (int): + Hidden layer dimension of feedforward network. + segment_length (int): + Length of each input segment. + dropout (float, optional): + Dropout probability. (Default: 0.0) + activation (str, optional): + Activation function to use in feedforward network. + Must be one of ("relu", "gelu", "silu"). (Default: "relu") + left_context_length (int, optional): + Length of left context. (Default: 0) + max_memory_size (int, optional): + Maximum number of memory elements to use. (Default: 0) + weight_init_gain (float or None, optional): + Scale factor to apply when initializing attention module parameters. + (Default: ``None``) + tanh_on_mem (bool, optional): + If ``True``, applies tanh to memory elements. (Default: ``False``) + negative_inf (float, optional): + Value to use for negative infinity in attention weights. (Default: -1e8) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int, + segment_length: int, + dropout: float = 0.0, + activation: str = "relu", + left_context_length: int = 0, + max_memory_size: int = 0, + weight_init_gain: Optional[float] = None, + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + ): + super().__init__() + + self.attention = EmformerAttention( + d_model=d_model, + nhead=nhead, + dropout=dropout, + weight_init_gain=weight_init_gain, + tanh_on_mem=tanh_on_mem, + negative_inf=negative_inf, + ) + self.dropout = nn.Dropout(dropout) + self.summary_op = nn.AvgPool1d( + kernel_size=segment_length, stride=segment_length, ceil_mode=True + ) + + activation_module = _get_activation_module(activation) + self.pos_ff = nn.Sequential( + nn.LayerNorm(d_model), + nn.Linear(d_model, dim_feedforward), + activation_module, + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + nn.Dropout(dropout), + ) + self.layer_norm_input = nn.LayerNorm(d_model) + self.layer_norm_output = nn.LayerNorm(d_model) + + self.left_context_length = left_context_length + self.segment_length = segment_length + self.max_memory_size = max_memory_size + self.d_model = d_model + + self.use_memory = max_memory_size > 0 + + def _init_state( + self, + batch_size: int, + device: Optional[torch.device] + ) -> List[torch.Tensor]: + """Initialize states with zeros.""" + empty_memory = torch.zeros( + self.max_memory_size, batch_size, self.d_model, device=device + ) + left_context_key = torch.zeros( + self.left_context_length, batch_size, self.d_model, device=device + ) + left_context_val = torch.zeros( + self.left_context_length, batch_size, self.d_model, device=device + ) + past_length = torch.zeros( + 1, batch_size, dtype=torch.int32, device=device + ) + return [empty_memory, left_context_key, left_context_val, past_length] + + def _unpack_state( + self, + state: List[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Unpack cached states including: + 1) output memory from previous chunks in the lower layer; + 2) attention key and value of left context from proceeding chunk's + computation. + """ + past_length = state[3][0][0].item() + past_left_context_length = min(self.left_context_length, past_length) + past_memory_length = min( + self.max_memory_size, math.ceil(past_length / self.segment_length) + ) + pre_memory = state[0][-past_memory_length:] + left_context_key = state[1][-past_left_context_length:] + left_context_val = state[2][-past_left_context_length:] + return pre_memory, left_context_key, left_context_val + + def _pack_state( + self, + next_key: torch.Tensor, + next_val: torch.Tensor, + update_length: int, + memory: torch.Tensor, + state: List[torch.Tensor], + ) -> List[torch.Tensor]: + """Pack updated states including: + 1) output memory of current chunk in the lower layer; + 2) attention key and value in current chunk's computation, which would + be resued in next chunk's computation. + 3) length of current chunk. + """ + new_memory = torch.cat([state[0], memory]) + new_key = torch.cat([state[1], next_key]) + new_val = torch.cat([state[2], next_val]) + state[0] = new_memory[-self.max_memory_size:] + state[1] = new_key[-self.left_context_length:] + state[2] = new_val[-self.left_context_length:] + state[3] = state[3] + update_length + return state + + def _apply_pre_attention_layer_norm( + self, utterance: torch.Tensor, right_context: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply layer normalization before attention. """ + layer_norm_input = self.layer_norm_input( + torch.cat([right_context, utterance]) + ) + layer_norm_utterance = layer_norm_input[right_context.size(0):] + layer_norm_right_context = layer_norm_input[:right_context.size(0)] + return layer_norm_utterance, layer_norm_right_context + + def _apply_post_attention_ffn_layer_norm( + self, + output_right_context_utterance: torch.Tensor, + utterance: torch.Tensor, + right_context: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply feed forward and layer normalization after attention.""" + # Apply residual connection between input and attention output. + result = self.dropout(output_right_context_utterance) + \ + torch.cat([right_context, utterance]) + # Apply feedforward module and residual connection. + result = self.pos_ff(result) + result + # Apply layer normalization for output. + result = self.layer_norm_output(result) + + output_utterance = result[right_context.size(0):] + output_right_context = result[:right_context.size(0)] + return output_utterance, output_right_context + + def _apply_attention_forward( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + memory: torch.Tensor, + attention_mask: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply attention in non-infer mode. """ + if attention_mask is None: + raise ValueError( + "attention_mask must be not None in non-infer mode. " + ) + + if self.use_memory: + summary = self.summary_op( + utterance.permute(1, 2, 0) + ).permute(2, 0, 1) + else: + summary = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) + output_right_context_utterance, output_memory = self.attention( + utterance=utterance, + lengths=lengths, + right_context=right_context, + summary=summary, + memory=memory, + attention_mask=attention_mask, + ) + return output_right_context_utterance, output_memory + + def _apply_attention_infer( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + memory: torch.Tensor, + state: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + """Apply attention in infer mode. + 1) Unpack cached states including: + - memory from previous chunks in the lower layer; + - attention key and value of left context from proceeding + chunk's compuation; + 2) Apply attention computation; + 3) Pack updated states including: + - output memory of current chunk in the lower layer; + - attention key and value in current chunk's computation, which would + be resued in next chunk's computation. + - length of current chunk. + """ + if state is None: + state = self._init_state(utterance.size(1), device=utterance.device) + pre_memory, left_context_key, left_context_val = \ + self._unpack_state(state) + if self.use_memory: + summary = self.summary_op( + utterance.permute(1, 2, 0) + ).permute(2, 0, 1) + summary = summary[:1] + else: + summary = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) + output_right_context_utterance, output_memory, next_key, next_val = \ + self.attention.infer( + utterance=utterance, + lengths=lengths, + right_context=right_context, + summary=summary, + memory=pre_memory, + left_context_key=left_context_key, + left_context_val=left_context_val, + ) + state = self._pack_state( + next_key, next_val, utterance.size(0), memory, state + ) + return output_right_context_utterance, output_memory, state + + def forward( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + memory: torch.Tensor, + attention_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + r"""Forward pass for training. + 1) Apply layer normalization on input utterance and right context + before attention; + 2) Apply attention module, compute updated utterance, right context, + and memory; + 3) Apply feed forward module and layer normalization on output utterance + and right context. + + B: batch size; + D: embedding dimension; + R: length of right_context; + U: length of utterance; + M: length of memory. + + Args: + utterance (torch.Tensor): + Utterance frames, with shape (U, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing + number of valid frames for i-th batch element in utterance. + right_context (torch.Tensor): + Right context frames, with shape (R, B, D). + memory (torch.Tensor): + Memory elements, with shape (M, B, D). + attention_mask (torch.Tensor): + Attention mask for underlying attention module. + + Returns: + A tuple containing 3 tensors: + - output utterance, with shape (U, B, D). + - output right context, with shape (R, B, D). + - output memory, with shape (M, B, D). + """ + ( + layer_norm_utterance, + layer_norm_right_context, + ) = self._apply_pre_attention_layer_norm(utterance, right_context) + output_right_context_utterance, output_memory = \ + self._apply_attention_forward( + layer_norm_utterance, + lengths, + layer_norm_right_context, + memory, + attention_mask, + ) + output_utterance, output_right_context = \ + self._apply_post_attention_ffn_layer_norm( + output_right_context_utterance, + utterance, + right_context + ) + return output_utterance, output_right_context, output_memory + + @torch.jit.export + def infer( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + memory: torch.Tensor, + state: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]: + """Forward pass for inference. + + 1) Apply layer normalization on input utterance and right context + before attention; + 2) Apply attention module with cached state, compute updated utterance, + right context, and memory, and update state; + 3) Apply feed forward module and layer normalization on output utterance + and right context. + + B: batch size; + D: embedding dimension; + R: length of right_context; + U: length of utterance; + M: length of memory. + + Args: + utterance (torch.Tensor): + Utterance frames, with shape (U, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing + number of valid frames for i-th batch element in utterance. + right_context (torch.Tensor): + Right context frames, with shape (R, B, D). + memory (torch.Tensor): + Memory elements, with shape (M, B, D). + state (List[torch.Tensor], optional): + List of tensors representing layer internal state generated in + preceding computation. (default=None) + + Returns: + (Tensor, Tensor, List[torch.Tensor], Tensor): + - output utterance, with shape (U, B, D); + - output right_context, with shape (R, B, D); + - output memory, with shape (M, B, D); + - output state. + """ + ( + layer_norm_utterance, + layer_norm_right_context, + ) = self._apply_pre_attention_layer_norm(utterance, right_context) + output_right_context_utterance, output_memory, output_state = \ + self._apply_attention_infer( + layer_norm_utterance, + lengths, + layer_norm_right_context, + memory, + state + ) + output_utterance, output_right_context = \ + self._apply_post_attention_ffn_layer_norm( + output_right_context_utterance, + utterance, + right_context + ) + return ( + output_utterance, + output_right_context, + output_memory, + output_state + ) + + +class EmformerEncoder(nn.Module): + """Implements the Emformer architecture introduced in + *Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency + Streaming Speech Recognition* + [:footcite:`shi2021emformer`]. + + Args: + d_model (int): + Input dimension. + nhead (int): + Number of attention heads in each emformer layer. + dim_feedforward (int): + Hidden layer dimension of each emformer layer's feedforward network. + num_encoder_layers (int): + Number of emformer layers to instantiate. + chunk_length (int): + Length of each input segment. + dropout (float, optional): + Dropout probability. (default: 0.0) + activation (str, optional): + Activation function to use in each emformer layer's feedforward network. + Must be one of ("relu", "gelu", "silu"). (default: "relu") + left_context_length (int, optional): + Length of left context. (default: 0) + right_context_length (int, optional): + Length of right context. (default: 0) + max_memory_size (int, optional): + Maximum number of memory elements to use. (default: 0) + weight_init_scale_strategy (str, optional): + Per-layer weight initialization scaling strategy. must be one of + ("depthwise", "constant", ``none``). (default: "depthwise") + tanh_on_mem (bool, optional): + If ``true``, applies tanh to memory elements. (default: ``false``) + negative_inf (float, optional): + Value to use for negative infinity in attention weights. (default: -1e8) + + examples: + >>> emformer = emformer(512, 8, 2048, 20, 4, right_context_length=1) + >>> input = torch.rand(128, 400, 512) # batch, num_frames, feature_dim + >>> lengths = torch.randint(1, 200, (128,)) # batch + >>> output = emformer(input, lengths) + >>> input = torch.rand(128, 5, 512) + >>> lengths = torch.ones(128) * 5 + >>> output, lengths, states = emformer.infer(input, lengths, None) + """ + + def __init__( + self, + chunk_length: int, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + dropout: float = 0.1, + activation: str = "relu", + left_context_length: int = 0, + right_context_length: int = 0, + max_memory_size: int = 0, + weight_init_scale_strategy: str = "depthwise", + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + ): + super().__init__() + + self.use_memory = max_memory_size > 0 + self.memory_op = nn.AvgPool1d( + kernel_size=chunk_length, + stride=chunk_length, + ceil_mode=True, + ) + + weight_init_gains = _get_weight_init_gains( + weight_init_scale_strategy, num_encoder_layers + ) + self.emformer_layers = nn.ModuleList( + [ + EmformerLayer( + d_model, + nhead, + dim_feedforward, + chunk_length, + dropout=dropout, + activation=activation, + left_context_length=left_context_length, + max_memory_size=max_memory_size, + weight_init_gain=weight_init_gains[layer_idx], + tanh_on_mem=tanh_on_mem, + negative_inf=negative_inf, + ) + for layer_idx in range(num_encoder_layers) + ] + ) + + self.left_context_length = left_context_length + self.right_context_length = right_context_length + self.chunk_length = chunk_length + self.max_memory_size = max_memory_size + + def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: + """Hard copy each chunk's right context and concat them. """ + T = x.shape[0] + num_segs = math.ceil( + (T - self.right_context_length) / self.chunk_length + ) + right_context_blocks = [] + for seg_idx in range(num_segs - 1): + start = (seg_idx + 1) * self.chunk_length + end = start + self.right_context_length + right_context_blocks.append(x[start:end]) + right_context_blocks.append(x[-self.right_context_length:]) + return torch.cat(right_context_blocks) + + def _gen_attention_mask_col_widths( + self, chunk_idx: int, U: int + ) -> List[int]: + """Calculate column widths (key, value) in attention mask for the + chunk_idx chunk.""" + num_chunks = math.ceil(U / self.chunk_length) + rc = self.right_context_length + lc = self.left_context_length + rc_start = chunk_idx * rc + rc_end = rc_start + rc + chunk_start = max(chunk_idx * self.chunk_length - lc, 0) + chunk_end = min((chunk_idx + 1) * self.chunk_length, U) + R = rc * num_chunks + + if self.use_memory: + m_start = max(chunk_idx - self.max_memory_size, 0) + M = num_chunks - 1 + col_widths = [ + m_start, # before memory + chunk_idx - m_start, # memory + M - chunk_idx, # after memory + rc_start, # before right context + rc, # right context + R - rc_end, # after right context + chunk_start, # before chunk + chunk_end - chunk_start, # chunk + U - chunk_end, # after chunk + ] + else: + col_widths = [ + rc_start, # before right context + rc, # right context + R - rc_end, # after right context + chunk_start, # before chunk + chunk_end - chunk_start, # chunk + U - chunk_end, # after chunk + ] + + return col_widths + + def _gen_attention_mask(self, utterance: torch.Tensor) -> torch.Tensor: + """Generate attention mask for underlying chunk-wise attention + computation, where chunk-wise connections are filled with `False`, + and other unnecessary connections beyond chunk are filled with `True`. + + R: length of right_context; + U: length of utterance; + S: length of summary; + M: length of memory; + Q: length of attention query; + KV: length of attention key and value. + + The shape of attention mask is (Q, KV). + If self.use_memory is `True`: + query = [right_context, utterance, summary]; + key, value = [memory, right_context, utterance]; + Q = R + U + S, KV = M + R + U. + Otherwise: + query = [right_context, utterance] + key, value = [right_context, utterance] + Q = R + U, KV = R + U. + """ + U = utterance.size(0) + num_chunks = math.ceil(U / self.chunk_length) + + right_context_mask = [] + utterance_mask = [] + summary_mask = [] + + if self.use_memory: + num_cols = 9 + # right context and utterance both attend to memory, right context, + # utterance + right_context_utterance_cols_mask = \ + [idx in [1, 4, 7] for idx in range(num_cols)] + # summary attends to right context, utterance + summary_cols_mask = [idx in [4, 7] for idx in range(num_cols)] + masks_to_concat = [right_context_mask, utterance_mask, summary_mask] + else: + num_cols = 6 + # right context and utterance both attend to right context and + # utterance + right_context_utterance_cols_mask = \ + [idx in [1, 4] for idx in range(num_cols)] + summary_cols_mask = None + masks_to_concat = [right_context_mask, utterance_mask] + + for chunk_idx in range(num_chunks): + col_widths = self._gen_attention_mask_col_widths(chunk_idx, U) + + right_context_mask_block = _gen_attention_mask_block( + col_widths, + right_context_utterance_cols_mask, + self.right_context_length, + utterance.device + ) + right_context_mask.append(right_context_mask_block) + + utterance_mask_block = _gen_attention_mask_block( + col_widths, + right_context_utterance_cols_mask, + min( + self.chunk_length, + U - chunk_idx * self.chunk_length, + ), + utterance.device, + ) + utterance_mask.append(utterance_mask_block) + + if summary_cols_mask is not None: + summary_mask_block = _gen_attention_mask_block( + col_widths, summary_cols_mask, 1, utterance.device + ) + summary_mask.append(summary_mask_block) + + attention_mask = ( + 1 - torch.cat([torch.cat(mask) for mask in masks_to_concat]) + ).to(torch.bool) + return attention_mask + + def forward( + self, x: torch.Tensor, lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass for training and non-streaming inference. + + B: batch size; + D: input dimension; + U: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (U + right_context_length, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x. + It is the true lengths without containing the right_context. + + Returns: + (Tensor, Tensor): + - output utterance frames, with shape (U, B, D). + - output lengths, with shape (B,) and i-th element representing + number of valid frames for i-th batch element in output frames. + """ + assert x.size(0) == torch.max(lengths).item() + \ + self.right_context_length + right_context = self._gen_right_context(x) + utterance = x[:-self.right_context_length] + attention_mask = self._gen_attention_mask(utterance) + memory = ( + self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1] + if self.use_memory + else torch.empty(0).to(dtype=x.dtype, device=x.device) + ) + output = utterance + for layer in self.emformer_layers: + output, right_context, memory = \ + layer(output, lengths, right_context, memory, attention_mask) + + return output, lengths + + @torch.jit.export + def infer( + self, + x: torch.Tensor, + lengths: torch.Tensor, + states: Optional[List[List[torch.Tensor]]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]: + """Forward pass for streaming inference. + + B: batch size; + D: input dimension; + U: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (chunk_length + right_context_length, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x. + It contains the right_context. + states (List[List[torch.Tensor]], optional): + Cached states from proceeding chunk's computation, where each + element (List[torch.Tensor]) corresponding to each emformer layer. + (default: None) + + Returns: + (Tensor, Tensor, List[List[torch.Tensor]]): + - output utterance frames, with shape (U, B, D). + - output lengths, with shape (B,) and i-th element representing + number of valid frames for i-th batch element in output frames. + - updated states from current chunk's computation. + """ + assert x.size(0) == self.chunk_length + self.right_context_length, ( + "Per configured chunk_length and right_context_length, " + f"expected size of {self.chunk_length + self.right_context_length} " + f"for dimension 1 of x, but got {x.size(1)}." + ) + right_context = x[-self.right_context_length:] + utterance = x[:-self.right_context_length] + output_lengths = torch.clamp(lengths - self.right_context_length, min=0) + memory = ( + self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) + if self.use_memory + else torch.empty(0).to(dtype=x.dtype, device=x.device) + ) + output = utterance + output_states: List[List[torch.Tensor]] = [] + for layer_idx, layer in enumerate(self.emformer_layers): + output, right_context, output_state, memory = layer.infer( + output, + output_lengths, + right_context, + None if states is None else states[layer_idx], + memory, + ) + output_states.append(output_state) + + return output, output_lengths, output_states + + +class Emformer(EncoderInterface): + def __init__( + self, + num_features: int, + output_dim: int, + chunk_length: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + dropout: float = 0.1, + vgg_frontend: bool = False, + activation: str = "relu", + left_context_length: int = 0, + right_context_length: int = 0, + max_memory_size: int = 0, + weight_init_scale_strategy: str = "depthwise", + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + ): + super().__init__() + + self.subsampling_factor = subsampling_factor + self.right_context_length = right_context_length + if subsampling_factor != 4: + raise NotImplementedError( + "Support only 'subsampling_factor=4'." + ) + if chunk_length % 4 != 0: + raise NotImplementedError( + "chunk_length must be a mutiple of 4." + ) + if left_context_length != 0 and left_context_length % 4 != 0: + raise NotImplementedError( + "left_context_length must be a mutiple of 4." + ) + if right_context_length != 0 and right_context_length % 4 != 0: + raise NotImplementedError( + "right_context_length must be a mutiple of 4." + ) + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + if vgg_frontend: + self.encoder_embed = VggSubsampling(num_features, d_model) + else: + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.encoder_pos = PositionalEncoding(d_model, dropout) + + self.encoder = EmformerEncoder( + chunk_length // 4, + d_model, + nhead, + dim_feedforward, + num_encoder_layers, + dropout, + activation, + left_context_length=left_context_length // 4, + right_context_length=right_context_length // 4, + max_memory_size=max_memory_size, + weight_init_scale_strategy=weight_init_scale_strategy, + tanh_on_mem=tanh_on_mem, + negative_inf=negative_inf, + ) + + # TODO(fangjun): remove dropout + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), nn.Linear(d_model, output_dim) + ) + + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass for training and non-streaming inference. + + B: batch size; + D: feature dimension; + U: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (B, U + right_context_length, D). + x_lens (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x. + It is the true lengths without containing the right_context. + + Returns: + (Tensor, Tensor): + - output logits, with shape (B, U // 4, D). + - logits lengths, with shape (B,) and i-th element representing + number of valid frames for i-th batch element in output frames. + """ + x = self.encoder_embed(x) + x = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # Caution: We assume the subsampling factor is 4! + lengths = x_lens // 4 + assert x.size(0) == \ + lengths.max().item() + self.right_context_length // 4 + + output, output_lengths = self.encoder(x, lengths) # (T, N, C) + + logits = self.encoder_output_layer(output) + logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return logits, output_lengths + + @torch.jit.export + def infer( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + states: Optional[List[List[torch.Tensor]]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]: + """Forward pass for streaming inference. + + B: batch size; + D: feature dimension; + U: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (B, U + right_context_length, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x. + It is the true lengths without containing the right_context. + states (List[List[torch.Tensor]], optional): + Cached states from proceeding chunk's computation, where each + element (List[torch.Tensor]) corresponding to each emformer layer. + (default: None) + Returns: + (Tensor, Tensor): + - output logits, with shape (B, U // 4, D). + - logits lengths, with shape (B,) and i-th element representing + number of valid frames for i-th batch element in output frames. + - updated states from current chunk's computation. + """ + x = self.encoder_embed(x) + x = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # Caution: We assume the subsampling factor is 4! + lengths = x_lens // 4 + assert x.size(0) == lengths.max().item() + output, output_lengths, output_states = \ + self.encoder.infer(x, lengths, states) # (T, N, C) + + logits = self.encoder_output_layer(output) + logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return logits, output_lengths, output_states + + +class PositionalEncoding(nn.Module): + """This class implements the positional encoding + proposed in the following paper: + + - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf + + PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) + PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) + + Note:: + + 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) + = exp(-1* 2i / d_model * log(100000)) + = exp(2i * -(log(10000) / d_model)) + """ + + def __init__(self, d_model: int, dropout: float = 0.1) -> None: + """ + Args: + d_model: + Embedding dimension. + dropout: + Dropout probability to be applied to the output of this module. + """ + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = nn.Dropout(p=dropout) + # not doing: self.pe = None because of errors thrown by torchscript + self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32) + + def extend_pe(self, x: torch.Tensor) -> None: + """Extend the time t in the positional encoding if required. + + The shape of `self.pe` is (1, T1, d_model). The shape of the input x + is (N, T, d_model). If T > T1, then we change the shape of self.pe + to (N, T, d_model). Otherwise, nothing is done. + + Args: + x: + It is a tensor of shape (N, T, C). + Returns: + Return None. + """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + # Now pe is of shape (1, T, d_model), where T is x.size(1) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Add positional encoding. + + Args: + x: + Its shape is (N, T, C) + + Returns: + Return a tensor of shape (N, T, C) + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1), :] + return self.dropout(x) + From 9423b3899fccb321f974531f0b72a88a1518abf8 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Mon, 4 Apr 2022 22:16:46 +0800 Subject: [PATCH 175/234] Update emformer_pruned_transducer_stateless/emformer.py and upload emformer_pruned_transducer_stateless/test_emformer.py. --- .../emformer.py | 176 ++++----- .../test_emformer.py | 345 ++++++++++++++++++ 2 files changed, 408 insertions(+), 113 deletions(-) create mode 100644 egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py index 88b1a06fb..32498a2c1 100644 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py @@ -9,48 +9,6 @@ from encoder_interface import EncoderInterface from subsampling import Conv2dSubsampling, VggSubsampling -def _gen_padding_mask( - utterance: torch.Tensor, - right_context: torch.Tensor, - lengths: torch.Tensor, - mems: torch.Tensor, - left_context_key: Optional[torch.Tensor] = None, -) -> Optional[torch.Tensor]: - """Generate padding mask according to the length of the tensors - contained in the key. - - Args: - utterance: (U, B, D) - right_context: (R, B, D) - lengths: (B,) - mems: (M, B, D) - left_context_key: (L, B, D) - B is the batch size, D is the feature dimension, - U is the length of the utterance, - R is the length of the right context block, - M is the length of the memory block, - L is the length of the left context block - - Returns: - padding_mask: - Padding mask for the concatenated key tensor - [mems, right_context, left_context, utterance], - sharing for all queries, with shape of (M + R + L + U, B) - """ - assert utterance.size(0) == torch.max(lengths) - B = utterance.size(1) - M = mems.size(0) - R = right_context.size(0) - L = left_context_key.size(0) if left_context_key is not None else 0 - if B == 1: - # TODO: for infer mode? - padding_mask = None - else: - lengths_concat = M + R + L + lengths - padding_mask = make_pad_mask(lengths_concat) - return padding_mask - - def _get_activation_module(activation: str) -> nn.Module: if activation == "relu": return nn.ReLU() @@ -96,11 +54,6 @@ def _gen_attention_mask_block( return torch.cat(mask_block, dim=1) -def length_down_sampling(length): - # Caution: We assume the subsampling factor is 4! - return ((length - 1) // 2 - 1) // 2 - - class EmformerAttention(nn.Module): r"""Emformer layer attention module. @@ -239,7 +192,7 @@ class EmformerAttention(nn.Module): and compute query tensor with length Q = R + U + S. 2) Concat memory, right_context, utterance, and compute key, value tensors with length KV = M + R + U; - optionally with left_context_key and left_context_val (inference mode) + optionally with left_context_key and left_context_val (inference mode), then KV = M + R + L + U. 3) Compute entire attention scores with query, key, and value, then apply attention_mask to get underlying chunk-wise attention scores. @@ -284,7 +237,7 @@ class EmformerAttention(nn.Module): ).chunk(chunks=2, dim=2) if left_context_key is not None and left_context_val is not None: - # Now compute key and value with + # This is for inference mode. Now compute key and value with # [mems, right context, left context, uttrance] M = memory.size(0) R = right_context.size(0) @@ -328,8 +281,8 @@ class EmformerAttention(nn.Module): outputs = self.out_proj(attention) S = summary.size(0) - output_right_context_utterance = outputs[:-S] - output_memory = outputs[-S:] + output_right_context_utterance = outputs[:Q - S] + output_memory = outputs[Q - S:] if self.tanh_on_mem: output_memory = torch.tanh(output_memory) else: @@ -370,12 +323,12 @@ class EmformerAttention(nn.Module): Memory elements, with shape (M, B, D). attention_mask (torch.Tensor): Attention mask for underlying chunk-wise attention, - with shape (Q, KV). + with shape (Q, KV), where Q = R + U + S, KV = M + R + U. Returns: A tuple containing 2 tensors: - output of right context and utterance, with shape (R + U, B, D). - - memory output, with shape (M, B, D), where M = S - 1. + - memory output, with shape (M, B, D), where M = S - 1 or M = 0. """ output_right_context_utterance, output_memory, _, _ = \ self._forward_impl( @@ -418,7 +371,7 @@ class EmformerAttention(nn.Module): right_context (torch.Tensor): Right context frames, with shape (R, B, D). summary (torch.Tensor): - Summary elements, with shape (S, B, D). + Summary element, with shape (1, B, D), or empty. memory (torch.Tensor): Memory elements, with shape (M, B, D). left_context_key (torch,Tensor): @@ -431,7 +384,7 @@ class EmformerAttention(nn.Module): Returns: A tuple containing 4 tensors: - output of right context and utterance, with shape (R + U, B, D). - - memory output, with shape (S, B, D). + - memory output, with shape (1, B, D) or (0, B, D). - attention key of left context and utterance, which would be cached for next computation, with shape (L + U, B, D). - attention value of left context and utterance, which would be @@ -476,7 +429,7 @@ class EmformerLayer(nn.Module): Number of attention heads. dim_feedforward (int): Hidden layer dimension of feedforward network. - segment_length (int): + chunk_length (int): Length of each input segment. dropout (float, optional): Dropout probability. (Default: 0.0) @@ -501,7 +454,7 @@ class EmformerLayer(nn.Module): d_model: int, nhead: int, dim_feedforward: int, - segment_length: int, + chunk_length: int, dropout: float = 0.0, activation: str = "relu", left_context_length: int = 0, @@ -513,7 +466,7 @@ class EmformerLayer(nn.Module): super().__init__() self.attention = EmformerAttention( - d_model=d_model, + embed_dim=d_model, nhead=nhead, dropout=dropout, weight_init_gain=weight_init_gain, @@ -522,7 +475,7 @@ class EmformerLayer(nn.Module): ) self.dropout = nn.Dropout(dropout) self.summary_op = nn.AvgPool1d( - kernel_size=segment_length, stride=segment_length, ceil_mode=True + kernel_size=chunk_length, stride=chunk_length, ceil_mode=True ) activation_module = _get_activation_module(activation) @@ -538,7 +491,7 @@ class EmformerLayer(nn.Module): self.layer_norm_output = nn.LayerNorm(d_model) self.left_context_length = left_context_length - self.segment_length = segment_length + self.chunk_length = chunk_length self.max_memory_size = max_memory_size self.d_model = d_model @@ -576,11 +529,13 @@ class EmformerLayer(nn.Module): past_length = state[3][0][0].item() past_left_context_length = min(self.left_context_length, past_length) past_memory_length = min( - self.max_memory_size, math.ceil(past_length / self.segment_length) + self.max_memory_size, math.ceil(past_length / self.chunk_length) ) - pre_memory = state[0][-past_memory_length:] - left_context_key = state[1][-past_left_context_length:] - left_context_val = state[2][-past_left_context_length:] + pre_memory = state[0][self.max_memory_size - past_memory_length:] + left_context_key = \ + state[1][self.left_context_length - past_left_context_length:] + left_context_val = \ + state[2][self.left_context_length - past_left_context_length:] return pre_memory, left_context_key, left_context_val def _pack_state( @@ -600,9 +555,9 @@ class EmformerLayer(nn.Module): new_memory = torch.cat([state[0], memory]) new_key = torch.cat([state[1], next_key]) new_val = torch.cat([state[2], next_val]) - state[0] = new_memory[-self.max_memory_size:] - state[1] = new_key[-self.left_context_length:] - state[2] = new_val[-self.left_context_length:] + state[0] = new_memory[new_memory.size(0) - self.max_memory_size:] + state[1] = new_key[new_key.size(0) - self.left_context_length:] + state[2] = new_val[new_val.size(0) - self.left_context_length:] state[3] = state[3] + update_length return state @@ -749,7 +704,8 @@ class EmformerLayer(nn.Module): memory (torch.Tensor): Memory elements, with shape (M, B, D). attention_mask (torch.Tensor): - Attention mask for underlying attention module. + Attention mask for underlying attention module, + with shape (Q, KV), where Q = R + U + S, KV = M + R + U. Returns: A tuple containing 3 tensors: @@ -819,7 +775,7 @@ class EmformerLayer(nn.Module): (Tensor, Tensor, List[torch.Tensor], Tensor): - output utterance, with shape (U, B, D); - output right_context, with shape (R, B, D); - - output memory, with shape (M, B, D); + - output memory, with shape (1, B, D) or (0, B, D). - output state. """ ( @@ -883,15 +839,6 @@ class EmformerEncoder(nn.Module): If ``true``, applies tanh to memory elements. (default: ``false``) negative_inf (float, optional): Value to use for negative infinity in attention weights. (default: -1e8) - - examples: - >>> emformer = emformer(512, 8, 2048, 20, 4, right_context_length=1) - >>> input = torch.rand(128, 400, 512) # batch, num_frames, feature_dim - >>> lengths = torch.randint(1, 200, (128,)) # batch - >>> output = emformer(input, lengths) - >>> input = torch.rand(128, 5, 512) - >>> lengths = torch.ones(128) * 5 - >>> output, lengths, states = emformer.infer(input, lengths, None) """ def __init__( @@ -913,7 +860,7 @@ class EmformerEncoder(nn.Module): super().__init__() self.use_memory = max_memory_size > 0 - self.memory_op = nn.AvgPool1d( + self.init_memory_op = nn.AvgPool1d( kernel_size=chunk_length, stride=chunk_length, ceil_mode=True, @@ -957,7 +904,7 @@ class EmformerEncoder(nn.Module): start = (seg_idx + 1) * self.chunk_length end = start + self.right_context_length right_context_blocks.append(x[start:end]) - right_context_blocks.append(x[-self.right_context_length:]) + right_context_blocks.append(x[T - self.right_context_length:]) return torch.cat(right_context_blocks) def _gen_attention_mask_col_widths( @@ -1095,31 +1042,34 @@ class EmformerEncoder(nn.Module): with shape (U + right_context_length, B, D). lengths (torch.Tensor): With shape (B,) and i-th element representing number of valid - utterance frames for i-th batch element in x. - It is the true lengths without containing the right_context. + utterance frames for i-th batch element in x, which contains the + right_context at the end. Returns: - (Tensor, Tensor): + A tuple of 2 tensors: - output utterance frames, with shape (U, B, D). - - output lengths, with shape (B,) and i-th element representing - number of valid frames for i-th batch element in output frames. + - output_lengths, with shape (B,), without containing the + right_context at the end. """ - assert x.size(0) == torch.max(lengths).item() + \ - self.right_context_length + # assert x.size(0) == torch.max(lengths).item() right_context = self._gen_right_context(x) - utterance = x[:-self.right_context_length] + utterance = x[:x.size(0) - self.right_context_length] + output_lengths = torch.clamp(lengths - self.right_context_length, min=0) attention_mask = self._gen_attention_mask(utterance) memory = ( - self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1] + self.init_memory_op( + utterance.permute(1, 2, 0) + ).permute(2, 0, 1)[:-1] if self.use_memory else torch.empty(0).to(dtype=x.dtype, device=x.device) ) output = utterance for layer in self.emformer_layers: - output, right_context, memory = \ - layer(output, lengths, right_context, memory, attention_mask) + output, right_context, memory = layer( + output, output_lengths, right_context, memory, attention_mask + ) - return output, lengths + return output, output_lengths @torch.jit.export def infer( @@ -1137,11 +1087,11 @@ class EmformerEncoder(nn.Module): Args: x (torch.Tensor): Utterance frames right-padded with right context frames, - with shape (chunk_length + right_context_length, B, D). + with shape (U + right_context_length, B, D). lengths (torch.Tensor): With shape (B,) and i-th element representing number of valid - utterance frames for i-th batch element in x. - It contains the right_context. + utterance frames for i-th batch element in x, which contains the + right_context at the end. states (List[List[torch.Tensor]], optional): Cached states from proceeding chunk's computation, where each element (List[torch.Tensor]) corresponding to each emformer layer. @@ -1150,8 +1100,8 @@ class EmformerEncoder(nn.Module): Returns: (Tensor, Tensor, List[List[torch.Tensor]]): - output utterance frames, with shape (U, B, D). - - output lengths, with shape (B,) and i-th element representing - number of valid frames for i-th batch element in output frames. + - output lengths, with shape (B,), without containing the + right_context at the end. - updated states from current chunk's computation. """ assert x.size(0) == self.chunk_length + self.right_context_length, ( @@ -1159,23 +1109,24 @@ class EmformerEncoder(nn.Module): f"expected size of {self.chunk_length + self.right_context_length} " f"for dimension 1 of x, but got {x.size(1)}." ) - right_context = x[-self.right_context_length:] - utterance = x[:-self.right_context_length] + right_context_start_idx = x.size(0) - self.right_context_length + right_context = x[right_context_start_idx:] + utterance = x[:right_context_start_idx] output_lengths = torch.clamp(lengths - self.right_context_length, min=0) memory = ( - self.memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) + self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) if self.use_memory else torch.empty(0).to(dtype=x.dtype, device=x.device) ) output = utterance output_states: List[List[torch.Tensor]] = [] for layer_idx, layer in enumerate(self.emformer_layers): - output, right_context, output_state, memory = layer.infer( + output, right_context, memory, output_state = layer.infer( output, output_lengths, right_context, - None if states is None else states[layer_idx], memory, + None if states is None else states[layer_idx], ) output_states.append(output_state) @@ -1272,24 +1223,23 @@ class Emformer(EncoderInterface): with shape (B, U + right_context_length, D). x_lens (torch.Tensor): With shape (B,) and i-th element representing number of valid - utterance frames for i-th batch element in x. - It is the true lengths without containing the right_context. + utterance frames for i-th batch element in x, containing the + right_context at the end. Returns: (Tensor, Tensor): - output logits, with shape (B, U // 4, D). - - logits lengths, with shape (B,) and i-th element representing - number of valid frames for i-th batch element in output frames. + - logits lengths, with shape (B,), without containing the + right_context at the end. """ + # TODO: x.shape x = self.encoder_embed(x) x = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) # Caution: We assume the subsampling factor is 4! lengths = x_lens // 4 - assert x.size(0) == \ - lengths.max().item() + self.right_context_length // 4 - + assert x.size(0) == lengths.max().item() output, output_lengths = self.encoder(x, lengths) # (T, N, C) logits = self.encoder_output_layer(output) @@ -1316,8 +1266,8 @@ class Emformer(EncoderInterface): with shape (B, U + right_context_length, D). lengths (torch.Tensor): With shape (B,) and i-th element representing number of valid - utterance frames for i-th batch element in x. - It is the true lengths without containing the right_context. + utterance frames for i-th batch element in x, containing the + right_context at the end. states (List[List[torch.Tensor]], optional): Cached states from proceeding chunk's computation, where each element (List[torch.Tensor]) corresponding to each emformer layer. @@ -1325,8 +1275,8 @@ class Emformer(EncoderInterface): Returns: (Tensor, Tensor): - output logits, with shape (B, U // 4, D). - - logits lengths, with shape (B,) and i-th element representing - number of valid frames for i-th batch element in output frames. + - logits lengths, with shape (B,), without containing the + right_context at the end. - updated states from current chunk's computation. """ x = self.encoder_embed(x) diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py new file mode 100644 index 000000000..ae93a4c8f --- /dev/null +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py @@ -0,0 +1,345 @@ +import torch + + +def test_emformer_attention_forward(): + from emformer import EmformerAttention + + B, D = 2, 256 + U, R = 12, 2 + chunk_length = 2 + attention = EmformerAttention(embed_dim=D, nhead=8) + + for use_memory in [True, False]: + if use_memory: + S = U // chunk_length + M = S - 1 + else: + S, M = 0, 0 + + Q, KV = R + U + S, M + R + U + utterance = torch.randn(U, B, D) + lengths = torch.randint(1, U + 1, (B,)) + lengths[0] = U + right_context = torch.randn(R, B, D) + summary = torch.randn(S, B, D) + memory = torch.randn(M, B, D) + attention_mask = torch.rand(Q, KV) >= 0.5 + + output_right_context_utterance, output_memory = attention( + utterance, + lengths, + right_context, + summary, + memory, + attention_mask, + ) + assert output_right_context_utterance.shape == (R + U, B, D) + assert output_memory.shape == (M, B, D) + + +def test_emformer_attention_infer(): + from emformer import EmformerAttention + + B, D = 2, 256 + R, L = 4, 2 + chunk_length = 2 + U = chunk_length + attention = EmformerAttention(embed_dim=D, nhead=8) + + for use_memory in [True, False]: + if use_memory: + S, M = 1, 3 + else: + S, M = 0, 0 + + utterance = torch.randn(U, B, D) + lengths = torch.randint(1, U + 1, (B,)) + lengths[0] = U + right_context = torch.randn(R, B, D) + summary = torch.randn(S, B, D) + memory = torch.randn(M, B, D) + left_context_key = torch.randn(L, B, D) + left_context_val = torch.randn(L, B, D) + + output_right_context_utterance, output_memory, next_key, next_val = \ + attention.infer( + utterance, + lengths, + right_context, + summary, + memory, + left_context_key, + left_context_val, + ) + assert output_right_context_utterance.shape == (R + U, B, D) + assert output_memory.shape == (S, B, D) + assert next_key.shape == (L + U, B, D) + assert next_val.shape == (L + U, B, D) + + +def test_emformer_layer_forward(): + from emformer import EmformerLayer + + B, D = 2, 256 + U, R, L = 12, 2, 5 + chunk_length = 2 + + for use_memory in [True, False]: + if use_memory: + S = U // chunk_length + M = S - 1 + else: + S, M = 0, 0 + + layer = EmformerLayer( + d_model=D, + nhead=8, + dim_feedforward=1024, + chunk_length=chunk_length, + left_context_length=L, + max_memory_size=M, + ) + + Q, KV = R + U + S, M + R + U + utterance = torch.randn(U, B, D) + lengths = torch.randint(1, U + 1, (B,)) + lengths[0] = U + right_context = torch.randn(R, B, D) + memory = torch.randn(M, B, D) + attention_mask = torch.rand(Q, KV) >= 0.5 + + output_utterance, output_right_context, output_memory = layer( + utterance, + lengths, + right_context, + memory, + attention_mask, + ) + assert output_utterance.shape == (U, B, D) + assert output_right_context.shape == (R, B, D) + assert output_memory.shape == (M, B, D) + + +def test_emformer_layer_infer(): + from emformer import EmformerLayer + + B, D = 2, 256 + R, L = 2, 5 + chunk_length = 2 + U = chunk_length + + for use_memory in [True, False]: + if use_memory: + M = 3 + else: + M = 0 + + layer = EmformerLayer( + d_model=D, + nhead=8, + dim_feedforward=1024, + chunk_length=chunk_length, + left_context_length=L, + max_memory_size=M, + ) + + utterance = torch.randn(U, B, D) + lengths = torch.randint(1, U + 1, (B,)) + lengths[0] = U + right_context = torch.randn(R, B, D) + memory = torch.randn(M, B, D) + state = None + output_utterance, output_right_context, output_memory, output_state = \ + layer.infer( + utterance, + lengths, + right_context, + memory, + state, + ) + assert output_utterance.shape == (U, B, D) + assert output_right_context.shape == (R, B, D) + if use_memory: + assert output_memory.shape == (1, B, D) + else: + assert output_memory.shape == (0, B, D) + assert len(output_state) == 4 + assert output_state[0].shape == (M, B, D) + assert output_state[1].shape == (L, B, D) + assert output_state[2].shape == (L, B, D) + assert output_state[3].shape == (1, B) + + +def test_emformer_encoder_forward(): + from emformer import EmformerEncoder + + B, D = 2, 256 + U, R, L = 12, 2, 5 + chunk_length = 2 + + for use_memory in [True, False]: + if use_memory: + S = U // chunk_length + M = S - 1 + else: + S, M = 0, 0 + + encoder = EmformerEncoder( + chunk_length=chunk_length, + d_model=D, + dim_feedforward=1024, + num_encoder_layers=2, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + ) + + x = torch.randn(U + R, B, D) + lengths = torch.randint(1, U + R + 1, (B,)) + lengths[0] = U + R + + output, output_lengths = encoder(x, lengths) + assert output.shape == (U, B, D) + assert torch.equal( + output_lengths, torch.clamp(lengths - R, min=0) + ) + + +def test_emformer_encoder_infer(): + from emformer import EmformerEncoder + + B, D = 2, 256 + R, L = 2, 5 + chunk_length = 2 + U = chunk_length + num_chunks = 3 + num_encoder_layers = 2 + + for use_memory in [True, False]: + if use_memory: + M = 3 + else: + M = 0 + + encoder = EmformerEncoder( + chunk_length=chunk_length, + d_model=D, + dim_feedforward=1024, + num_encoder_layers=num_encoder_layers, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + ) + + states = None + for chunk_idx in range(num_chunks): + x = torch.randn(U + R, B, D) + lengths = torch.randint(1, U + R + 1, (B,)) + lengths[0] = U + R + output, output_lengths, states = \ + encoder.infer(x, lengths, states) + assert output.shape == (U, B, D) + assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0)) + assert len(states) == num_encoder_layers + for state in states: + assert len(state) == 4 + assert state[0].shape == (M, B, D) + assert state[1].shape == (L, B, D) + assert state[2].shape == (L, B, D) + assert torch.equal( + state[3], (chunk_idx + 1) * U * torch.ones_like(state[3]) + ) + + +def test_emformer_forward(): + from emformer import Emformer + num_features = 80 + output_dim = 1000 + chunk_length = 16 + L, R = 32, 16 + B, D, U = 2, 256, 48 + for use_memory in [True, False]: + if use_memory: + M = 3 + else: + M = 0 + model = Emformer( + num_features=num_features, + output_dim=output_dim, + chunk_length=chunk_length, + subsampling_factor=4, + d_model=D, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + vgg_frontend=False, + ) + x = torch.randn(B, U + R, num_features) + x_lens = torch.randint(1, U + R + 1, (B,)) + x_lens[0] = U + R + logits, output_lengths = model(x, x_lens) + assert logits.shape == (B, U // 4, output_dim) + assert torch.equal( + output_lengths, torch.clamp(x_lens // 4 - R // 4, min=0) + ) + + +def test_emformer_infer(): + from emformer import Emformer + num_features = 80 + output_dim = 1000 + chunk_length = 16 + U = chunk_length + L, R = 32, 16 + B, D = 2, 256 + num_chunks = 3 + num_encoder_layers = 2 + for use_memory in [True, False]: + if use_memory: + M = 3 + else: + M = 0 + model = Emformer( + num_features=num_features, + output_dim=output_dim, + chunk_length=chunk_length, + subsampling_factor=4, + d_model=D, + num_encoder_layers=num_encoder_layers, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + vgg_frontend=False, + ) + states = None + for chunk_idx in range(num_chunks): + x = torch.randn(B, U + R, num_features) + x_lens = torch.randint(1, U + R + 1, (B,)) + x_lens[0] = U + R + logits, output_lengths, states = \ + model.infer(x, x_lens, states) + assert logits.shape == (B, U // 4, output_dim) + assert torch.equal( + output_lengths, torch.clamp(x_lens // 4 - R // 4, min=0) + ) + assert len(states) == num_encoder_layers + for state in states: + assert len(state) == 4 + assert state[0].shape == (M, B, D) + assert state[1].shape == (L // 4, B, D) + assert state[2].shape == (L // 4, B, D) + assert torch.equal( + state[3], + (chunk_idx + 1) * U // 4 * torch.ones_like(state[3]) + ) + + +if __name__ == "__main__": + test_emformer_attention_forward() + test_emformer_attention_infer() + test_emformer_layer_forward() + test_emformer_layer_infer() + test_emformer_encoder_forward() + test_emformer_encoder_infer() + test_emformer_forward() + test_emformer_infer() From a41e93437c608f2061f72796c7260e3d5ff7bc7c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 6 Apr 2022 12:36:58 +0800 Subject: [PATCH 176/234] Change some defaults in LR-setting rule. --- egs/librispeech/ASR/pruned_transducer_stateless2/optim.py | 2 +- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 2b40dda45..a2e0463da 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -60,7 +60,7 @@ class Eve(Optimizer): """ def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8, - weight_decay=3e-4, target_rms=0.1): + weight_decay=1e-3, target_rms=0.1): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 3b8f0499f..306a2195b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -161,7 +161,7 @@ def get_parser(): parser.add_argument( "--lr-power", type=float, - default=0.5, + default=0.75, help="Power in LR-setting rule", ) @@ -780,8 +780,7 @@ def run(rank, world_size, args): optimizer = Eve( model.parameters(), - lr=params.initial_lr, betas=(0.9, 0.98), - eps=1e-9, weight_decay=3e-04, target_rms=0.1) + lr=params.initial_lr) scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lambda step: ((params.lr_num_steps/(step + params.lr_num_steps)) ** params.lr_power)) From 61486a0f76d79e941257f87efc7b10188fb48b44 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 6 Apr 2022 13:17:26 +0800 Subject: [PATCH 177/234] Remove initial_speed --- .../ASR/pruned_transducer_stateless2/conformer.py | 8 -------- .../ASR/pruned_transducer_stateless2/decoder.py | 6 ------ .../ASR/pruned_transducer_stateless2/joiner.py | 3 --- 3 files changed, 17 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 4797cce08..94c6aa90c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -956,30 +956,22 @@ class Conv2dSubsampling(nn.Module): assert in_channels >= 7 super().__init__() - # This initial_speed is to slightly slow down the relative speed of - # training during the warmup phase by increasing the magnitude of the - # initial parameter values. The intention is to allow us to - # use a higher lr_factor. - initial_speed = 0.5 self.conv = nn.Sequential( ScaledConv2d( in_channels=1, out_channels=layer1_channels, kernel_size=3, padding=1, - initial_speed=initial_speed, ), ActivationBalancer(channel_dim=1), DoubleSwish(), ScaledConv2d( in_channels=layer1_channels, out_channels=layer2_channels, kernel_size=3, stride=2, - initial_speed=initial_speed, ), ActivationBalancer(channel_dim=1), DoubleSwish(), ScaledConv2d( in_channels=layer2_channels, out_channels=layer3_channels, kernel_size=3, stride=2, - initial_speed=initial_speed, ), ActivationBalancer(channel_dim=1), DoubleSwish(), diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index 3291ad877..c23568ae9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -56,16 +56,10 @@ class Decoder(nn.Module): """ super().__init__() - # This initial_speed is to slightly slow down the relative speed of - # training during the warmup phase by increasing the magnitude of the - # initial parameter values. The intention is to allow us to - # use a higher lr_factor. - initial_speed = 0.5 self.embedding = ScaledEmbedding( num_embeddings=vocab_size, embedding_dim=decoder_dim, padding_idx=blank_id, - initial_speed=initial_speed ) self.blank_id = blank_id diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 752a5f774..2299a0a8c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -27,9 +27,6 @@ class Joiner(nn.Module): vocab_size: int): super().__init__() - # We don't bother giving the 'initial_speed' arg to the decoder - # submodules, because it does not affect the initial convergence of the - # system (only the simple joiner is involved in that). self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim) self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim) self.output_linear = ScaledLinear(joiner_dim, vocab_size) From 374eacdd5ca02c92f6f1f6615b9b72c237172eb7 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Thu, 7 Apr 2022 21:32:59 +0800 Subject: [PATCH 178/234] First upload emformer_pruned_transducer_stateless recipe, refator emformer codes from torchaudio. --- .../asr_datamodule.py | 1 + .../beam_search.py | 1 + .../decode.py | 549 ++++++++++ .../decoder.py | 1 + .../emformer.py | 110 +- .../encoder_interface.py | 1 + .../joiner.py | 1 + .../model.py | 1 + .../noam.py | 104 ++ .../subsampling.py | 167 +-- .../test_emformer.py | 30 +- .../test_subsampling.py | 25 - .../train.py | 998 ++++++++++++++++++ 13 files changed, 1691 insertions(+), 298 deletions(-) create mode 120000 egs/librispeech/ASR/emformer_pruned_transducer_stateless/asr_datamodule.py create mode 120000 egs/librispeech/ASR/emformer_pruned_transducer_stateless/beam_search.py create mode 100755 egs/librispeech/ASR/emformer_pruned_transducer_stateless/decode.py create mode 120000 egs/librispeech/ASR/emformer_pruned_transducer_stateless/decoder.py create mode 120000 egs/librispeech/ASR/emformer_pruned_transducer_stateless/encoder_interface.py create mode 120000 egs/librispeech/ASR/emformer_pruned_transducer_stateless/joiner.py create mode 120000 egs/librispeech/ASR/emformer_pruned_transducer_stateless/model.py create mode 100644 egs/librispeech/ASR/emformer_pruned_transducer_stateless/noam.py mode change 100644 => 120000 egs/librispeech/ASR/emformer_pruned_transducer_stateless/subsampling.py delete mode 100644 egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_subsampling.py create mode 100755 egs/librispeech/ASR/emformer_pruned_transducer_stateless/train.py diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/asr_datamodule.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/asr_datamodule.py new file mode 120000 index 000000000..b4e5427e0 --- /dev/null +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/beam_search.py new file mode 120000 index 000000000..227d2247c --- /dev/null +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/decode.py new file mode 100755 index 000000000..c40b01dfa --- /dev/null +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/decode.py @@ -0,0 +1,549 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./transducer_emformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./transducer_emformer/exp \ + --max-duration 100 \ + --decoding-method greedy_search + +(2) beam search +./transducer_emformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./transducer_emformer/exp \ + --max-duration 100 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./transducer_emformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./transducer_emformer/exp \ + --max-duration 100 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search +./transducer_emformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./transducer_emformer/exp \ + --max-duration 1500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +""" + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--avg-last-n", + type=int, + default=0, + help="""If positive, --epoch and --avg are ignored and it + will use the last n checkpoints exp_dir/checkpoint-xxx.pt + where xxx is the number of processed batches while + saving that checkpoint. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer_emformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = model.device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif params.decoding_method == "fast_beam_search": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps + } + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 100 + else: + log_interval = 2 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if params.avg_last_n > 0: + filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.to(device) + model.eval() + model.device = device + + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/decoder.py new file mode 120000 index 000000000..0d5f10dc0 --- /dev/null +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py index 32498a2c1..edba2e0b3 100644 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py @@ -1,5 +1,6 @@ import math from typing import List, Optional, Tuple +import warnings import torch from torch import nn @@ -1051,7 +1052,6 @@ class EmformerEncoder(nn.Module): - output_lengths, with shape (B,), without containing the right_context at the end. """ - # assert x.size(0) == torch.max(lengths).item() right_context = self._gen_right_context(x) utterance = x[:x.size(0) - self.right_context_length] output_lengths = torch.clamp(lengths - self.right_context_length, min=0) @@ -1168,11 +1168,11 @@ class Emformer(EncoderInterface): ) if left_context_length != 0 and left_context_length % 4 != 0: raise NotImplementedError( - "left_context_length must be a mutiple of 4." + "left_context_length must be 0 or a mutiple of 4." ) if right_context_length != 0 and right_context_length % 4 != 0: raise NotImplementedError( - "right_context_length must be a mutiple of 4." + "right_context_length must be 0 or a mutiple of 4." ) # self.encoder_embed converts the input of shape (N, T, num_features) @@ -1185,8 +1185,6 @@ class Emformer(EncoderInterface): else: self.encoder_embed = Conv2dSubsampling(num_features, d_model) - self.encoder_pos = PositionalEncoding(d_model, dropout) - self.encoder = EmformerEncoder( chunk_length // 4, d_model, @@ -1228,19 +1226,20 @@ class Emformer(EncoderInterface): Returns: (Tensor, Tensor): - - output logits, with shape (B, U // 4, D). + - output logits, with shape (B, ((U - 1) // 2 - 1) // 2, D). - logits lengths, with shape (B,), without containing the right_context at the end. """ - # TODO: x.shape x = self.encoder_embed(x) - x = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) # Caution: We assume the subsampling factor is 4! - lengths = x_lens // 4 - assert x.size(0) == lengths.max().item() - output, output_lengths = self.encoder(x, lengths) # (T, N, C) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + x_lens = ((x_lens - 1) // 2 - 1) // 2 + assert x.size(0) == x_lens.max().item() + + output, output_lengths = self.encoder(x, x_lens) # (T, N, C) logits = self.encoder_output_layer(output) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -1274,99 +1273,24 @@ class Emformer(EncoderInterface): (default: None) Returns: (Tensor, Tensor): - - output logits, with shape (B, U // 4, D). + - output logits, with shape (B, ((U - 1) // 2 - 1) // 2, D). - logits lengths, with shape (B,), without containing the right_context at the end. - updated states from current chunk's computation. """ x = self.encoder_embed(x) - x = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) # Caution: We assume the subsampling factor is 4! - lengths = x_lens // 4 - assert x.size(0) == lengths.max().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + x_lens = ((x_lens - 1) // 2 - 1) // 2 + assert x.size(0) == x_lens.max().item() + output, output_lengths, output_states = \ - self.encoder.infer(x, lengths, states) # (T, N, C) + self.encoder.infer(x, x_lens, states) # (T, N, C) logits = self.encoder_output_layer(output) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) return logits, output_lengths, output_states - - -class PositionalEncoding(nn.Module): - """This class implements the positional encoding - proposed in the following paper: - - - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf - - PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) - PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) - - Note:: - - 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) - = exp(-1* 2i / d_model * log(100000)) - = exp(2i * -(log(10000) / d_model)) - """ - - def __init__(self, d_model: int, dropout: float = 0.1) -> None: - """ - Args: - d_model: - Embedding dimension. - dropout: - Dropout probability to be applied to the output of this module. - """ - super().__init__() - self.d_model = d_model - self.xscale = math.sqrt(self.d_model) - self.dropout = nn.Dropout(p=dropout) - # not doing: self.pe = None because of errors thrown by torchscript - self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32) - - def extend_pe(self, x: torch.Tensor) -> None: - """Extend the time t in the positional encoding if required. - - The shape of `self.pe` is (1, T1, d_model). The shape of the input x - is (N, T, d_model). If T > T1, then we change the shape of self.pe - to (N, T, d_model). Otherwise, nothing is done. - - Args: - x: - It is a tensor of shape (N, T, C). - Returns: - Return None. - """ - if self.pe is not None: - if self.pe.size(1) >= x.size(1): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) - position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, self.d_model, 2, dtype=torch.float32) - * -(math.log(10000.0) / self.d_model) - ) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - # Now pe is of shape (1, T, d_model), where T is x.size(1) - self.pe = pe.to(device=x.device, dtype=x.dtype) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Add positional encoding. - - Args: - x: - Its shape is (N, T, C) - - Returns: - Return a tensor of shape (N, T, C) - """ - self.extend_pe(x) - x = x * self.xscale + self.pe[:, : x.size(1), :] - return self.dropout(x) - diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/encoder_interface.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/encoder_interface.py new file mode 120000 index 000000000..aa5d0217a --- /dev/null +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/encoder_interface.py @@ -0,0 +1 @@ +../transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/joiner.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/joiner.py new file mode 120000 index 000000000..81ad47c55 --- /dev/null +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/model.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/model.py new file mode 120000 index 000000000..a61a0a23f --- /dev/null +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/noam.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/noam.py new file mode 100644 index 000000000..e46bf35fb --- /dev/null +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/noam.py @@ -0,0 +1,104 @@ +# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +class Noam(object): + """ + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/subsampling.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/subsampling.py deleted file mode 100644 index 7d0ad44a6..000000000 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/subsampling.py +++ /dev/null @@ -1,166 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -import torch.nn as nn - - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where T' == T // 4. - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__(self, idim: int, odim: int) -> None: - """ - Args: - idim: - Input dim. The input shape is (N, T, idim). - Caution: It requires: T >= 4, idim >= 7 - odim: - Output dim. The output shape is (N, T // 4, odim) - """ - assert idim >= 7 - super().__init__() - self.conv_1 = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), - nn.ReLU(), - ) - self.conv_2 = nn.Sequential( - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), - nn.ReLU(), - ) - self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - - Returns: - Return a tensor of shape (N, T // 4, odim) - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) - # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - x = nn.functional.pad(x, (0, 0, 0, 1), "constant", 0) - # x is of shape (N, 1, T + 1, idim) - x = self.conv_1(x) - # Now x is of shape (N, odim, T // 2, (idim - 1) // 2) - x = nn.functional.pad(x, (0, 0, 0, 1), "constant", 0) - # x is of shape (N, odim, T // 2 + 1, (idim - 1) // 2) - x = self.conv_2(x) - # Now x is of shape (N, odim, T // 4, ((idim - 1) // 2 - 1) // 2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - # Now x is of shape (N, T // 4, odim) - return x - - -class VggSubsampling(nn.Module): - """Trying to follow the setup described in the following paper: - https://arxiv.org/pdf/1910.09799.pdf - - This paper is not 100% explicit so I am guessing to some extent, - and trying to compare with other VGG implementations. - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where approximates T' = T//4. - """ - - def __init__(self, idim: int, odim: int) -> None: - """Construct a VggSubsampling object. - - This uses 2 VGG blocks with 2 Conv2d layers each, - subsampling its input by a factor of 4 in the time dimensions. - - Args: - idim: - Input dim. The input shape is (N, T, idim). - Caution: It requires: T >= 4, idim >= 4. - odim: - Output dim. The output shape is (N, T // 4, odim) - """ - super().__init__() - - cur_channels = 1 - layers = [] - block_dims = [32, 64] - - # The decision to use padding=1 for the 1st convolution, then padding=0 - # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by - # a back-compatibility concern so that the number of frames at the - # output would be equal to: - # (((T-1)//2)-1)//2. - # We can consider changing this by using padding=1 on the - # 2nd convolution, so the num-frames at the output would be T//4. - for block_dim in block_dims: - layers.append( - torch.nn.Conv2d( - in_channels=cur_channels, - out_channels=block_dim, - kernel_size=3, - padding=1, - stride=1, - ) - ) - layers.append(torch.nn.ReLU()) - layers.append( - torch.nn.Conv2d( - in_channels=block_dim, - out_channels=block_dim, - kernel_size=3, - padding=1, - stride=1, - ) - ) - layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=False - ) - ) - cur_channels = block_dim - - self.layers = nn.Sequential(*layers) - - self.out = nn.Linear(block_dims[-1] * (idim // 4), odim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - - Returns: - Return a tensor of shape (N, T // 4, odim) - """ - x = x.unsqueeze(1) - x = self.layers(x) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - return x diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/subsampling.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/subsampling.py new file mode 120000 index 000000000..6fee09e58 --- /dev/null +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/subsampling.py @@ -0,0 +1 @@ +../conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py index ae93a4c8f..4c9cbba9c 100644 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py @@ -255,9 +255,9 @@ def test_emformer_forward(): from emformer import Emformer num_features = 80 output_dim = 1000 - chunk_length = 16 - L, R = 32, 16 - B, D, U = 2, 256, 48 + chunk_length = 8 + L, R = 128, 4 + B, D, U = 2, 256, 80 for use_memory in [True, False]: if use_memory: M = 3 @@ -274,13 +274,14 @@ def test_emformer_forward(): max_memory_size=M, vgg_frontend=False, ) - x = torch.randn(B, U + R, num_features) - x_lens = torch.randint(1, U + R + 1, (B,)) - x_lens[0] = U + R + x = torch.randn(B, U + R + 3, num_features) + x_lens = torch.randint(1, U + R + 3 + 1, (B,)) + x_lens[0] = U + R + 3 logits, output_lengths = model(x, x_lens) assert logits.shape == (B, U // 4, output_dim) assert torch.equal( - output_lengths, torch.clamp(x_lens // 4 - R // 4, min=0) + output_lengths, + torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0) ) @@ -288,9 +289,9 @@ def test_emformer_infer(): from emformer import Emformer num_features = 80 output_dim = 1000 - chunk_length = 16 + chunk_length = 8 U = chunk_length - L, R = 32, 16 + L, R = 128, 4 B, D = 2, 256 num_chunks = 3 num_encoder_layers = 2 @@ -313,14 +314,15 @@ def test_emformer_infer(): ) states = None for chunk_idx in range(num_chunks): - x = torch.randn(B, U + R, num_features) - x_lens = torch.randint(1, U + R + 1, (B,)) - x_lens[0] = U + R + x = torch.randn(B, U + R + 3, num_features) + x_lens = torch.randint(1, U + R + 3 + 1, (B,)) + x_lens[0] = U + R + 3 logits, output_lengths, states = \ model.infer(x, x_lens, states) assert logits.shape == (B, U // 4, output_dim) assert torch.equal( - output_lengths, torch.clamp(x_lens // 4 - R // 4, min=0) + output_lengths, + torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0) ) assert len(states) == num_encoder_layers for state in states: @@ -330,7 +332,7 @@ def test_emformer_infer(): assert state[2].shape == (L // 4, B, D) assert torch.equal( state[3], - (chunk_idx + 1) * U // 4 * torch.ones_like(state[3]) + U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]) ) diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_subsampling.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_subsampling.py deleted file mode 100644 index 338688564..000000000 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_subsampling.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch -from subsampling import Conv2dSubsampling, VggSubsampling - - -def test_conv2d_subsampling(): - B, idim, odim = 1, 80, 512 - model = Conv2dSubsampling(idim, odim) - for t in range(4, 50): - x = torch.randn(B, t, idim) - outputs = model(x) - assert outputs.shape == (B, t // 4, odim) - - -def test_vgg_subsampling(): - B, idim, odim = 1, 80, 512 - model = VggSubsampling(idim, odim) - for t in range(4, 50): - x = torch.randn(B, t, idim) - outputs = model(x) - assert outputs.shape == (B, t // 4, odim) - - -if __name__ == "__main__": - test_conv2d_subsampling() - test_vgg_subsampling() diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/train.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/train.py new file mode 100755 index 000000000..d7285f4a5 --- /dev/null +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/train.py @@ -0,0 +1,998 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./transducer_emformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 0 \ + --exp-dir transducer_emformer/exp \ + --full-libri 1 \ + --max-duration 300 +""" + + +import argparse +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from emformer import Emformer +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from noam import Noam +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter + +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import save_checkpoint_with_global_batch_idx +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import ( + AttributeDict, + MetricsTracker, + measure_gradient_norms, + measure_weight_norms, + optim_step_and_measure_param_change, + setup_logger, + str2bool, +) + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--attention-dim", + type=int, + default=512, + help="Attention dim for the Emformer", + ) + + parser.add_argument( + "--nhead", + type=int, + default=8, + help="Number of attention heads for the Emformer", + ) + + parser.add_argument( + "--dim-feedforward", + type=int, + default=2048, + help="Feed-forward dimension for the Emformer", + ) + + parser.add_argument( + "--num-encoder-layers", + type=int, + default=12, + help="Number of encoder layers for the Emformer", + ) + + parser.add_argument( + "--left-context-length", + type=int, + default=120, + help="Number of frames for the left context in the Emformer", + ) + + parser.add_argument( + "--chunk-length", + type=int, + default=16, + help="Number of frames for each segment in the Emformer", + ) + + parser.add_argument( + "--right-context-length", + type=int, + default=4, + help="Number of frames for right context in the Emformer", + ) + + parser.add_argument( + "--memory-size", + type=int, + default=0, + help="Number of entries in the memory for the Emformer", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + transducer_emformer/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer_emformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lr-factor", + type=float, + default=5.0, + help="The lr_factor for Noam optimizer", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" + "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=8000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - attention_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + "log_diagnostics": False, + # parameters for Emformer + "feature_dim": 80, + "subsampling_factor": 4, + "vgg_frontend": False, + # parameters for decoder + "embedding_dim": 512, + # parameters for Noam + "warm_step": 80000, # For the 100h subset, use 20000 + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Emformer( + num_features=params.feature_dim, + output_dim=params.vocab_size, + subsampling_factor=params.subsampling_factor, + d_model=params.attention_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + vgg_frontend=params.vgg_frontend, + left_context_length=params.left_context_length, + chunk_length=params.chunk_length, + right_context_length=params.right_context_length, + max_memory_size=params.memory_size, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + input_dim=params.vocab_size, + inner_dim=params.embedding_dim, + output_dim=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 0: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + sampler: Optional[CutSampler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + sampler=sampler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Emformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = model.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + loss = params.simple_loss_scale * simple_loss + pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + def maybe_log_gradients(tag: str): + if ( + params.log_diagnostics + and tb_writer is not None + and params.batch_idx_train % (params.log_interval * 5) == 0 + ): + tb_writer.add_scalars( + tag, + measure_gradient_norms(model, norm="l2"), + global_step=params.batch_idx_train, + ) + + def maybe_log_weights(tag: str): + if ( + params.log_diagnostics + and tb_writer is not None + and params.batch_idx_train % (params.log_interval * 5) == 0 + ): + tb_writer.add_scalars( + tag, + measure_weight_norms(model, norm="l2"), + global_step=params.batch_idx_train, + ) + + def maybe_log_param_relative_changes(): + if ( + params.log_diagnostics + and tb_writer is not None + and params.batch_idx_train % (params.log_interval * 5) == 0 + ): + deltas = optim_step_and_measure_param_change(model, optimizer) + tb_writer.add_scalars( + "train/relative_param_change_per_minibatch", + deltas, + global_step=params.batch_idx_train, + ) + else: + optimizer.step() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + loss.backward() + + maybe_log_weights("train/param_norms") + maybe_log_gradients("train/grad_norms") + maybe_log_param_relative_changes() + + optimizer.zero_grad() + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + params=params, + optimizer=optimizer, + sampler=train_dl.sampler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 800 + params.warm_step = 20000 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + model.device = device + + optimizer = Noam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + ) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + num_in_total = len(train_cuts) + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + num_left = len(train_cuts) + num_removed = num_in_total - num_left + removed_percent = num_removed / num_in_total * 100 + + logging.info(f"Before removing short and long utterances: {num_in_total}") + logging.info(f"After removing short and long utterances: {num_left}") + logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + sampler=train_dl.sampler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() From 2d1b90f7587411c1cc891110deb9ad4924adbd9a Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 8 Apr 2022 10:59:39 +0800 Subject: [PATCH 179/234] update the docs of Emformer class in emformer.py --- .../emformer.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py index edba2e0b3..91bb571c5 100644 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py @@ -1,14 +1,14 @@ import math -from typing import List, Optional, Tuple import warnings +from typing import List, Optional, Tuple import torch -from torch import nn - -from icefall.utils import make_pad_mask +import torch.nn as nn from encoder_interface import EncoderInterface from subsampling import Conv2dSubsampling, VggSubsampling +from icefall.utils import make_pad_mask + def _get_activation_module(activation: str) -> nn.Module: if activation == "relu": @@ -1213,12 +1213,12 @@ class Emformer(EncoderInterface): B: batch size; D: feature dimension; - U: length of utterance. + T: length of utterance. Args: x (torch.Tensor): Utterance frames right-padded with right context frames, - with shape (B, U + right_context_length, D). + with shape (B, T, D). x_lens (torch.Tensor): With shape (B,) and i-th element representing number of valid utterance frames for i-th batch element in x, containing the @@ -1226,7 +1226,8 @@ class Emformer(EncoderInterface): Returns: (Tensor, Tensor): - - output logits, with shape (B, ((U - 1) // 2 - 1) // 2, D). + - output logits, with shape (B, T', D), where + T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4. - logits lengths, with shape (B,), without containing the right_context at the end. """ @@ -1257,12 +1258,12 @@ class Emformer(EncoderInterface): B: batch size; D: feature dimension; - U: length of utterance. + T: length of utterance. Args: x (torch.Tensor): Utterance frames right-padded with right context frames, - with shape (B, U + right_context_length, D). + with shape (B, T, D). lengths (torch.Tensor): With shape (B,) and i-th element representing number of valid utterance frames for i-th batch element in x, containing the @@ -1273,7 +1274,8 @@ class Emformer(EncoderInterface): (default: None) Returns: (Tensor, Tensor): - - output logits, with shape (B, ((U - 1) // 2 - 1) // 2, D). + - output logits, with shape (B, T', D), where + T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4. - logits lengths, with shape (B,), without containing the right_context at the end. - updated states from current chunk's computation. From 6ee32cf7afd110783b5872e431e30308583abb21 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 8 Apr 2022 16:10:06 +0800 Subject: [PATCH 180/234] Set new scheduler --- .../ASR/pruned_transducer_stateless2/train.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 306a2195b..e06db45c0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -28,15 +28,17 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --exp-dir pruned_transducer_stateless2/exp \ --full-libri 1 \ --max-duration 300 \ - --initial-lr 0.002 \ - --lr-decay-steps 10000 \ - --num-lr-decays 4 + --initial-lr 0.003 \ + --lr-begin-steps 20000 \ + --lr-end-steps 50000 + """ import argparse import logging +import math import warnings from pathlib import Path from shutil import copyfile @@ -147,22 +149,22 @@ def get_parser(): parser.add_argument( "--initial-lr", type=float, - default=0.002, + default=0.003, help="The initial learning rate", ) parser.add_argument( - "--lr-num-steps", + "--lr-begin-steps", type=float, - default=3000, - help="Number of steps before we start to significantly decay the learning rate", + default=20000, + help="Number of steps that affects how rapidly the learning rate initially decreases" ) parser.add_argument( - "--lr-power", + "--lr-end-steps", type=float, - default=0.75, - help="Power in LR-setting rule", + default=50000, + help="Number of steps that affects how rapidly the learning rate finally decreases" ) parser.add_argument( @@ -783,7 +785,8 @@ def run(rank, world_size, args): lr=params.initial_lr) scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, - lambda step: ((params.lr_num_steps/(step + params.lr_num_steps)) ** params.lr_power)) + lambda step: (((step + params.lr_begin_steps) / params.lr_begin_steps) ** -0.5 * + math.exp(-step / params.lr_end_steps))) if checkpoints and "optimizer" in checkpoints: From f587cd527dd0075349d2a2f3502d3f62945679be Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 8 Apr 2022 16:24:21 +0800 Subject: [PATCH 181/234] Change exponential part of lrate to be epoch based --- .../ASR/pruned_transducer_stateless2/train.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index e06db45c0..d5da5d0e9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -27,10 +27,8 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --start-epoch 0 \ --exp-dir pruned_transducer_stateless2/exp \ --full-libri 1 \ - --max-duration 300 \ - --initial-lr 0.003 \ - --lr-begin-steps 20000 \ - --lr-end-steps 50000 + --max-duration 300 + """ @@ -161,10 +159,10 @@ def get_parser(): ) parser.add_argument( - "--lr-end-steps", + "--lr-end-epochs", type=float, - default=50000, - help="Number of steps that affects how rapidly the learning rate finally decreases" + default=10, + help="Number of epochs that affects how rapidly the learning rate finally decreases" ) parser.add_argument( @@ -783,10 +781,13 @@ def run(rank, world_size, args): optimizer = Eve( model.parameters(), lr=params.initial_lr) + + # The `epoch` variable in the lambda expression binds to the value below + # in `for epoch in range(params.start_epoch, params.num_epochs):`. scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lambda step: (((step + params.lr_begin_steps) / params.lr_begin_steps) ** -0.5 * - math.exp(-step / params.lr_end_steps))) + math.exp(-epoch / params.lr_end_epochs))) if checkpoints and "optimizer" in checkpoints: From 0f8ee68af22657c90c5e2762a5e48e1d09b7ce0c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 8 Apr 2022 16:53:42 +0800 Subject: [PATCH 182/234] Fix bug --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index d5da5d0e9..038469282 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -784,6 +784,7 @@ def run(rank, world_size, args): # The `epoch` variable in the lambda expression binds to the value below # in `for epoch in range(params.start_epoch, params.num_epochs):`. + epoch = 0 scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lambda step: (((step + params.lr_begin_steps) / params.lr_begin_steps) ** -0.5 * From d58002c4146e24d3e19b3ece0ce90ef29c32bdff Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 8 Apr 2022 20:31:32 +0800 Subject: [PATCH 183/234] update emformer.py --- .pre-commit-config.yaml | 1 + .../emformer.py | 406 ++++++++++-------- 2 files changed, 233 insertions(+), 174 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b59784dbf..62d34864b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,7 @@ repos: hooks: - id: black args: [--line-length=80] + additional_dependencies: ['click==8.0.1'] - repo: https://github.com/PyCQA/flake8 rev: 3.9.2 diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py index 91bb571c5..4ba19ebae 100644 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py @@ -1,3 +1,22 @@ +# Copyright 2022 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# It is modified based on +# https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py. + import math import warnings from typing import List, Optional, Tuple @@ -22,29 +41,32 @@ def _get_activation_module(activation: str) -> nn.Module: def _get_weight_init_gains( - weight_init_scale_strategy: Optional[str], - num_layers: int + weight_init_scale_strategy: Optional[str], num_layers: int ) -> List[Optional[float]]: if weight_init_scale_strategy is None: return [None for _ in range(num_layers)] elif weight_init_scale_strategy == "depthwise": - return [1.0 / math.sqrt(layer_idx + 1) - for layer_idx in range(num_layers)] + return [ + 1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers) + ] elif weight_init_scale_strategy == "constant": return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)] else: - raise ValueError(f"Unsupported weight_init_scale_strategy value" - f"{weight_init_scale_strategy}") + raise ValueError( + f"Unsupported weight_init_scale_strategy value" + f"{weight_init_scale_strategy}" + ) def _gen_attention_mask_block( col_widths: List[int], col_mask: List[bool], num_rows: int, - device: torch.device + device: torch.device, ) -> torch.Tensor: - assert len(col_widths) == len(col_mask), ( - "Length of col_widths must match that of col_mask") + assert len(col_widths) == len( + col_mask + ), "Length of col_widths must match that of col_mask" mask_block = [ torch.ones(num_rows, col_width, device=device) @@ -99,9 +121,7 @@ class EmformerAttention(nn.Module): self.scaling = (self.embed_dim // self.nhead) ** -0.5 - self.emb_to_key_value = nn.Linear( - embed_dim, 2 * embed_dim, bias=True - ) + self.emb_to_key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=True) self.emb_to_query = nn.Linear(embed_dim, embed_dim, bias=True) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) @@ -119,7 +139,7 @@ class EmformerAttention(nn.Module): attention_mask: torch.Tensor, padding_mask: Optional[torch.Tensor], ) -> torch.Tensor: - """ Given the entire attention weights, mask out unecessary connections + """Given the entire attention weights, mask out unecessary connections and optionally with padding positions, to obtain underlying chunk-wise attention probabilities. @@ -154,7 +174,7 @@ class EmformerAttention(nn.Module): ) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), - self.negative_inf + self.negative_inf, ) attention_weights_float = attention_weights_float.view( B * self.nhead, Q, -1 @@ -164,9 +184,7 @@ class EmformerAttention(nn.Module): attention_weights_float, dim=-1 ).type_as(attention_weights) attention_probs = nn.functional.dropout( - attention_probs, - p=float(self.dropout), - training=self.training + attention_probs, p=float(self.dropout), training=self.training ) return attention_probs @@ -181,7 +199,7 @@ class EmformerAttention(nn.Module): left_context_key: Optional[torch.Tensor] = None, left_context_val: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ Underlying chunk-wise attention implementation. + """Underlying chunk-wise attention implementation. L: length of left_context; S: length of summary; @@ -242,14 +260,28 @@ class EmformerAttention(nn.Module): # [mems, right context, left context, uttrance] M = memory.size(0) R = right_context.size(0) - key = torch.cat([key[:M + R], left_context_key, key[M + R:]]) - value = torch.cat([value[:M + R], left_context_val, value[M + R:]]) + right_context_end_idx = M + R + key = torch.cat( + [ + key[:right_context_end_idx], + left_context_key, + key[right_context_end_idx:], + ] + ) + value = torch.cat( + [ + value[:right_context_end_idx], + left_context_val, + value[right_context_end_idx:], + ] + ) # Compute attention weights from query, key, and value. reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous().view( - -1, B * self.nhead, self.embed_dim // self.nhead - ).transpose(0, 1) for tensor in [query, key, value] + tensor.contiguous() + .view(-1, B * self.nhead, self.embed_dim // self.nhead) + .transpose(0, 1) + for tensor in [query, key, value] ] attention_weights = torch.bmm( reshaped_query * self.scaling, reshaped_key.transpose(1, 2) @@ -272,18 +304,21 @@ class EmformerAttention(nn.Module): attention = torch.bmm(attention_probs, reshaped_value) Q = query.size(0) assert attention.shape == ( - B * self.nhead, Q, self.embed_dim // self.nhead, + B * self.nhead, + Q, + self.embed_dim // self.nhead, ) - attention = attention.transpose(0, 1).contiguous().view( - Q, B, self.embed_dim + attention = ( + attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) ) # Apply output projection. outputs = self.out_proj(attention) S = summary.size(0) - output_right_context_utterance = outputs[:Q - S] - output_memory = outputs[Q - S:] + summary_start_idx = Q - S + output_right_context_utterance = outputs[:summary_start_idx] + output_memory = outputs[summary_start_idx:] if self.tanh_on_mem: output_memory = torch.tanh(output_memory) else: @@ -331,15 +366,14 @@ class EmformerAttention(nn.Module): - output of right context and utterance, with shape (R + U, B, D). - memory output, with shape (M, B, D), where M = S - 1 or M = 0. """ - output_right_context_utterance, output_memory, _, _ = \ - self._forward_impl( - utterance, - lengths, - right_context, - summary, - memory, - attention_mask - ) + ( + output_right_context_utterance, + output_memory, + _, + _, + ) = self._forward_impl( + utterance, lengths, right_context, summary, memory, attention_mask + ) return output_right_context_utterance, output_memory[:-1] @torch.jit.export @@ -394,29 +428,38 @@ class EmformerAttention(nn.Module): # query: [right context, utterance, summary] Q = right_context.size(0) + utterance.size(0) + summary.size(0) # key, value: [memory, right context, left context, uttrance] - KV = memory.size(0) + right_context.size(0) + \ - left_context_key.size(0) + utterance.size(0) - attention_mask = torch.zeros( - Q, KV - ).to(dtype=torch.bool, device=utterance.device) + KV = ( + memory.size(0) + + right_context.size(0) + + left_context_key.size(0) + + utterance.size(0) + ) + attention_mask = torch.zeros(Q, KV).to( + dtype=torch.bool, device=utterance.device + ) # Disallow attention bettween the summary vector with the memory bank - attention_mask[-1, :memory.size(0)] = True - output_right_context_utterance, output_memory, key, value = \ - self._forward_impl( - utterance, - lengths, - right_context, - summary, - memory, - attention_mask, - left_context_key=left_context_key, - left_context_val=left_context_val, - ) + attention_mask[-1, : memory.size(0)] = True + ( + output_right_context_utterance, + output_memory, + key, + value, + ) = self._forward_impl( + utterance, + lengths, + right_context, + summary, + memory, + attention_mask, + left_context_key=left_context_key, + left_context_val=left_context_val, + ) + right_context_end_idx = memory.size(0) + right_context.size(0) return ( output_right_context_utterance, output_memory, - key[memory.size(0) + right_context.size(0):], - value[memory.size(0) + right_context.size(0):], + key[right_context_end_idx:], + value[right_context_end_idx:], ) @@ -499,9 +542,7 @@ class EmformerLayer(nn.Module): self.use_memory = max_memory_size > 0 def _init_state( - self, - batch_size: int, - device: Optional[torch.device] + self, batch_size: int, device: Optional[torch.device] ) -> List[torch.Tensor]: """Initialize states with zeros.""" empty_memory = torch.zeros( @@ -519,8 +560,7 @@ class EmformerLayer(nn.Module): return [empty_memory, left_context_key, left_context_val, past_length] def _unpack_state( - self, - state: List[torch.Tensor] + self, state: List[torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Unpack cached states including: 1) output memory from previous chunks in the lower layer; @@ -532,11 +572,13 @@ class EmformerLayer(nn.Module): past_memory_length = min( self.max_memory_size, math.ceil(past_length / self.chunk_length) ) - pre_memory = state[0][self.max_memory_size - past_memory_length:] - left_context_key = \ - state[1][self.left_context_length - past_left_context_length:] - left_context_val = \ - state[2][self.left_context_length - past_left_context_length:] + memory_start_idx = self.max_memory_size - past_memory_length + pre_memory = state[0][memory_start_idx:] + left_context_start_idx = ( + self.left_context_length - past_left_context_length + ) + left_context_key = state[1][left_context_start_idx:] + left_context_val = state[2][left_context_start_idx:] return pre_memory, left_context_key, left_context_val def _pack_state( @@ -556,40 +598,46 @@ class EmformerLayer(nn.Module): new_memory = torch.cat([state[0], memory]) new_key = torch.cat([state[1], next_key]) new_val = torch.cat([state[2], next_val]) - state[0] = new_memory[new_memory.size(0) - self.max_memory_size:] - state[1] = new_key[new_key.size(0) - self.left_context_length:] - state[2] = new_val[new_val.size(0) - self.left_context_length:] + memory_start_idx = new_memory.size(0) - self.max_memory_size + state[0] = new_memory[memory_start_idx:] + key_start_idx = new_key.size(0) - self.left_context_length + state[1] = new_key[key_start_idx:] + val_start_idx = new_val.size(0) - self.left_context_length + state[2] = new_val[val_start_idx:] state[3] = state[3] + update_length return state def _apply_pre_attention_layer_norm( self, utterance: torch.Tensor, right_context: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - """Apply layer normalization before attention. """ + """Apply layer normalization before attention.""" layer_norm_input = self.layer_norm_input( torch.cat([right_context, utterance]) ) - layer_norm_utterance = layer_norm_input[right_context.size(0):] - layer_norm_right_context = layer_norm_input[:right_context.size(0)] + right_context_end_idx = right_context.size(0) + layer_norm_utterance = layer_norm_input[right_context_end_idx:] + layer_norm_right_context = layer_norm_input[:right_context_end_idx] return layer_norm_utterance, layer_norm_right_context def _apply_post_attention_ffn_layer_norm( self, output_right_context_utterance: torch.Tensor, utterance: torch.Tensor, - right_context: torch.Tensor + right_context: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """Apply feed forward and layer normalization after attention.""" # Apply residual connection between input and attention output. - result = self.dropout(output_right_context_utterance) + \ - torch.cat([right_context, utterance]) + result = self.dropout(output_right_context_utterance) + torch.cat( + [right_context, utterance] + ) # Apply feedforward module and residual connection. result = self.pos_ff(result) + result # Apply layer normalization for output. result = self.layer_norm_output(result) - output_utterance = result[right_context.size(0):] - output_right_context = result[:right_context.size(0)] + right_context_end_idx = right_context.size(0) + output_utterance = result[right_context_end_idx:] + output_right_context = result[:right_context_end_idx] return output_utterance, output_right_context def _apply_attention_forward( @@ -600,16 +648,16 @@ class EmformerLayer(nn.Module): memory: torch.Tensor, attention_mask: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: - """Apply attention in non-infer mode. """ + """Apply attention in non-infer mode.""" if attention_mask is None: raise ValueError( "attention_mask must be not None in non-infer mode. " ) if self.use_memory: - summary = self.summary_op( - utterance.permute(1, 2, 0) - ).permute(2, 0, 1) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + ) else: summary = torch.empty(0).to( dtype=utterance.dtype, device=utterance.device @@ -646,27 +694,32 @@ class EmformerLayer(nn.Module): """ if state is None: state = self._init_state(utterance.size(1), device=utterance.device) - pre_memory, left_context_key, left_context_val = \ - self._unpack_state(state) + pre_memory, left_context_key, left_context_val = self._unpack_state( + state + ) if self.use_memory: - summary = self.summary_op( - utterance.permute(1, 2, 0) - ).permute(2, 0, 1) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + ) summary = summary[:1] else: summary = torch.empty(0).to( dtype=utterance.dtype, device=utterance.device ) - output_right_context_utterance, output_memory, next_key, next_val = \ - self.attention.infer( - utterance=utterance, - lengths=lengths, - right_context=right_context, - summary=summary, - memory=pre_memory, - left_context_key=left_context_key, - left_context_val=left_context_val, - ) + ( + output_right_context_utterance, + output_memory, + next_key, + next_val, + ) = self.attention.infer( + utterance=utterance, + lengths=lengths, + right_context=right_context, + summary=summary, + memory=pre_memory, + left_context_key=left_context_key, + left_context_val=left_context_val, + ) state = self._pack_state( next_key, next_val, utterance.size(0), memory, state ) @@ -718,20 +771,22 @@ class EmformerLayer(nn.Module): layer_norm_utterance, layer_norm_right_context, ) = self._apply_pre_attention_layer_norm(utterance, right_context) - output_right_context_utterance, output_memory = \ - self._apply_attention_forward( - layer_norm_utterance, - lengths, - layer_norm_right_context, - memory, - attention_mask, - ) - output_utterance, output_right_context = \ - self._apply_post_attention_ffn_layer_norm( - output_right_context_utterance, - utterance, - right_context - ) + ( + output_right_context_utterance, + output_memory, + ) = self._apply_attention_forward( + layer_norm_utterance, + lengths, + layer_norm_right_context, + memory, + attention_mask, + ) + ( + output_utterance, + output_right_context, + ) = self._apply_post_attention_ffn_layer_norm( + output_right_context_utterance, utterance, right_context + ) return output_utterance, output_right_context, output_memory @torch.jit.export @@ -745,63 +800,66 @@ class EmformerLayer(nn.Module): ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]: """Forward pass for inference. - 1) Apply layer normalization on input utterance and right context - before attention; - 2) Apply attention module with cached state, compute updated utterance, - right context, and memory, and update state; - 3) Apply feed forward module and layer normalization on output utterance - and right context. + 1) Apply layer normalization on input utterance and right context + before attention; + 2) Apply attention module with cached state, compute updated utterance, + right context, and memory, and update state; + 3) Apply feed forward module and layer normalization on output + utterance and right context. - B: batch size; - D: embedding dimension; - R: length of right_context; - U: length of utterance; - M: length of memory. + B: batch size; + D: embedding dimension; + R: length of right_context; + U: length of utterance; + M: length of memory. - Args: - utterance (torch.Tensor): - Utterance frames, with shape (U, B, D). - lengths (torch.Tensor): - With shape (B,) and i-th element representing - number of valid frames for i-th batch element in utterance. - right_context (torch.Tensor): - Right context frames, with shape (R, B, D). - memory (torch.Tensor): - Memory elements, with shape (M, B, D). - state (List[torch.Tensor], optional): - List of tensors representing layer internal state generated in - preceding computation. (default=None) + Args: + utterance (torch.Tensor): + Utterance frames, with shape (U, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing + number of valid frames for i-th batch element in utterance. + right_context (torch.Tensor): + Right context frames, with shape (R, B, D). + memory (torch.Tensor): + Memory elements, with shape (M, B, D). + state (List[torch.Tensor], optional): + List of tensors representing layer internal state generated in + preceding computation. (default=None) - Returns: - (Tensor, Tensor, List[torch.Tensor], Tensor): - - output utterance, with shape (U, B, D); - - output right_context, with shape (R, B, D); - - output memory, with shape (1, B, D) or (0, B, D). - - output state. + Returns: + (Tensor, Tensor, List[torch.Tensor], Tensor): + - output utterance, with shape (U, B, D); + - output right_context, with shape (R, B, D); + - output memory, with shape (1, B, D) or (0, B, D). + - output state. """ ( layer_norm_utterance, layer_norm_right_context, ) = self._apply_pre_attention_layer_norm(utterance, right_context) - output_right_context_utterance, output_memory, output_state = \ - self._apply_attention_infer( - layer_norm_utterance, - lengths, - layer_norm_right_context, - memory, - state - ) - output_utterance, output_right_context = \ - self._apply_post_attention_ffn_layer_norm( - output_right_context_utterance, - utterance, - right_context - ) + ( + output_right_context_utterance, + output_memory, + output_state, + ) = self._apply_attention_infer( + layer_norm_utterance, + lengths, + layer_norm_right_context, + memory, + state, + ) + ( + output_utterance, + output_right_context, + ) = self._apply_post_attention_ffn_layer_norm( + output_right_context_utterance, utterance, right_context + ) return ( output_utterance, output_right_context, output_memory, - output_state + output_state, ) @@ -895,7 +953,7 @@ class EmformerEncoder(nn.Module): self.max_memory_size = max_memory_size def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: - """Hard copy each chunk's right context and concat them. """ + """Hard copy each chunk's right context and concat them.""" T = x.shape[0] num_segs = math.ceil( (T - self.right_context_length) / self.chunk_length @@ -905,7 +963,8 @@ class EmformerEncoder(nn.Module): start = (seg_idx + 1) * self.chunk_length end = start + self.right_context_length right_context_blocks.append(x[start:end]) - right_context_blocks.append(x[T - self.right_context_length:]) + last_right_context_start_idx = T - self.right_context_length + right_context_blocks.append(x[last_right_context_start_idx:]) return torch.cat(right_context_blocks) def _gen_attention_mask_col_widths( @@ -981,8 +1040,9 @@ class EmformerEncoder(nn.Module): num_cols = 9 # right context and utterance both attend to memory, right context, # utterance - right_context_utterance_cols_mask = \ - [idx in [1, 4, 7] for idx in range(num_cols)] + right_context_utterance_cols_mask = [ + idx in [1, 4, 7] for idx in range(num_cols) + ] # summary attends to right context, utterance summary_cols_mask = [idx in [4, 7] for idx in range(num_cols)] masks_to_concat = [right_context_mask, utterance_mask, summary_mask] @@ -990,8 +1050,9 @@ class EmformerEncoder(nn.Module): num_cols = 6 # right context and utterance both attend to right context and # utterance - right_context_utterance_cols_mask = \ - [idx in [1, 4] for idx in range(num_cols)] + right_context_utterance_cols_mask = [ + idx in [1, 4] for idx in range(num_cols) + ] summary_cols_mask = None masks_to_concat = [right_context_mask, utterance_mask] @@ -1002,7 +1063,7 @@ class EmformerEncoder(nn.Module): col_widths, right_context_utterance_cols_mask, self.right_context_length, - utterance.device + utterance.device, ) right_context_mask.append(right_context_mask_block) @@ -1053,13 +1114,13 @@ class EmformerEncoder(nn.Module): right_context at the end. """ right_context = self._gen_right_context(x) - utterance = x[:x.size(0) - self.right_context_length] + utterance = x[: x.size(0) - self.right_context_length] output_lengths = torch.clamp(lengths - self.right_context_length, min=0) attention_mask = self._gen_attention_mask(utterance) memory = ( - self.init_memory_op( - utterance.permute(1, 2, 0) - ).permute(2, 0, 1)[:-1] + self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1 + ] if self.use_memory else torch.empty(0).to(dtype=x.dtype, device=x.device) ) @@ -1159,13 +1220,9 @@ class Emformer(EncoderInterface): self.subsampling_factor = subsampling_factor self.right_context_length = right_context_length if subsampling_factor != 4: - raise NotImplementedError( - "Support only 'subsampling_factor=4'." - ) + raise NotImplementedError("Support only 'subsampling_factor=4'.") if chunk_length % 4 != 0: - raise NotImplementedError( - "chunk_length must be a mutiple of 4." - ) + raise NotImplementedError("chunk_length must be a mutiple of 4.") if left_context_length != 0 and left_context_length % 4 != 0: raise NotImplementedError( "left_context_length must be 0 or a mutiple of 4." @@ -1289,8 +1346,9 @@ class Emformer(EncoderInterface): x_lens = ((x_lens - 1) // 2 - 1) // 2 assert x.size(0) == x_lens.max().item() - output, output_lengths, output_states = \ - self.encoder.infer(x, x_lens, states) # (T, N, C) + output, output_lengths, output_states = self.encoder.infer( + x, x_lens, states + ) # (T, N, C) logits = self.encoder_output_layer(output) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) From 3e131891a2a466e279b6a7492361393d7f1a2093 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 8 Apr 2022 20:43:54 +0800 Subject: [PATCH 184/234] update test_emformer.py --- .../test_emformer.py | 66 ++++++++++--------- 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py index 4c9cbba9c..56cf2035e 100644 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py @@ -61,16 +61,20 @@ def test_emformer_attention_infer(): left_context_key = torch.randn(L, B, D) left_context_val = torch.randn(L, B, D) - output_right_context_utterance, output_memory, next_key, next_val = \ - attention.infer( - utterance, - lengths, - right_context, - summary, - memory, - left_context_key, - left_context_val, - ) + ( + output_right_context_utterance, + output_memory, + next_key, + next_val, + ) = attention.infer( + utterance, + lengths, + right_context, + summary, + memory, + left_context_key, + left_context_val, + ) assert output_right_context_utterance.shape == (R + U, B, D) assert output_memory.shape == (S, B, D) assert next_key.shape == (L + U, B, D) @@ -98,7 +102,7 @@ def test_emformer_layer_forward(): chunk_length=chunk_length, left_context_length=L, max_memory_size=M, - ) + ) Q, KV = R + U + S, M + R + U utterance = torch.randn(U, B, D) @@ -141,7 +145,7 @@ def test_emformer_layer_infer(): chunk_length=chunk_length, left_context_length=L, max_memory_size=M, - ) + ) utterance = torch.randn(U, B, D) lengths = torch.randint(1, U + 1, (B,)) @@ -149,14 +153,18 @@ def test_emformer_layer_infer(): right_context = torch.randn(R, B, D) memory = torch.randn(M, B, D) state = None - output_utterance, output_right_context, output_memory, output_state = \ - layer.infer( - utterance, - lengths, - right_context, - memory, - state, - ) + ( + output_utterance, + output_right_context, + output_memory, + output_state, + ) = layer.infer( + utterance, + lengths, + right_context, + memory, + state, + ) assert output_utterance.shape == (U, B, D) assert output_right_context.shape == (R, B, D) if use_memory: @@ -200,9 +208,7 @@ def test_emformer_encoder_forward(): output, output_lengths = encoder(x, lengths) assert output.shape == (U, B, D) - assert torch.equal( - output_lengths, torch.clamp(lengths - R, min=0) - ) + assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0)) def test_emformer_encoder_infer(): @@ -236,8 +242,7 @@ def test_emformer_encoder_infer(): x = torch.randn(U + R, B, D) lengths = torch.randint(1, U + R + 1, (B,)) lengths[0] = U + R - output, output_lengths, states = \ - encoder.infer(x, lengths, states) + output, output_lengths, states = encoder.infer(x, lengths, states) assert output.shape == (U, B, D) assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0)) assert len(states) == num_encoder_layers @@ -253,6 +258,7 @@ def test_emformer_encoder_infer(): def test_emformer_forward(): from emformer import Emformer + num_features = 80 output_dim = 1000 chunk_length = 8 @@ -281,12 +287,13 @@ def test_emformer_forward(): assert logits.shape == (B, U // 4, output_dim) assert torch.equal( output_lengths, - torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0) + torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0), ) def test_emformer_infer(): from emformer import Emformer + num_features = 80 output_dim = 1000 chunk_length = 8 @@ -317,12 +324,11 @@ def test_emformer_infer(): x = torch.randn(B, U + R + 3, num_features) x_lens = torch.randint(1, U + R + 3 + 1, (B,)) x_lens[0] = U + R + 3 - logits, output_lengths, states = \ - model.infer(x, x_lens, states) + logits, output_lengths, states = model.infer(x, x_lens, states) assert logits.shape == (B, U // 4, output_dim) assert torch.equal( output_lengths, - torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0) + torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0), ) assert len(states) == num_encoder_layers for state in states: @@ -332,7 +338,7 @@ def test_emformer_infer(): assert state[2].shape == (L // 4, B, D) assert torch.equal( state[3], - U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]) + U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]), ) From db72aee1f0e4987bf79b0578967f1c45be562dbc Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 9 Apr 2022 18:15:56 +0800 Subject: [PATCH 185/234] Set 2n rule.. --- .../ASR/pruned_transducer_stateless2/train.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 038469282..92509f4ec 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -154,15 +154,17 @@ def get_parser(): parser.add_argument( "--lr-begin-steps", type=float, - default=20000, + default=25000, help="Number of steps that affects how rapidly the learning rate initially decreases" ) parser.add_argument( "--lr-end-epochs", type=float, - default=10, - help="Number of epochs that affects how rapidly the learning rate finally decreases" + default=-1, + help="""Number of epochs that affects how rapidly the learning rate finally decreases; + if -1, will be set the same as --num-epochs + """ ) parser.add_argument( @@ -783,12 +785,14 @@ def run(rank, world_size, args): lr=params.initial_lr) # The `epoch` variable in the lambda expression binds to the value below - # in `for epoch in range(params.start_epoch, params.num_epochs):`. + # in `for epoch in range(params.start_epoch, params.num_epochs):`. But set it to 0 + # here to avoid crash in constructor. epoch = 0 + lr_end_epochs = params.lr_end_epochs if params.lr_end_epochs > 0 else params.num_epochs scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lambda step: (((step + params.lr_begin_steps) / params.lr_begin_steps) ** -0.5 * - math.exp(-epoch / params.lr_end_epochs))) + ((epoch + lr_end_epochs) / lr_end_epochs) ** -2.0)) if checkpoints and "optimizer" in checkpoints: From 4d41ee0caad855354bf95b6f1cab87072060a974 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 9 Apr 2022 18:37:03 +0800 Subject: [PATCH 186/234] Implement 2o schedule --- .../ASR/pruned_transducer_stateless2/optim.py | 12 ------------ .../ASR/pruned_transducer_stateless2/train.py | 16 ++++++++-------- 2 files changed, 8 insertions(+), 20 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index a2e0463da..e47c08657 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -100,9 +100,6 @@ class Eve(Optimizer): if p.grad is None: continue - - - # Perform optimization step grad = p.grad if grad.is_sparse: @@ -144,12 +141,3 @@ class Eve(Optimizer): p.addcdiv_(exp_avg, denom, value=-step_size) return loss - -# Note on avg-change per epoch.. -# suppose epoch is 4k iters. -# if avg-change as rms(diff) / rms(params) equals 0.2, and rms(params) = 0.1, -# then rm(diff) 0.1 * 0.2, var(diff) = (0.1 * 0.2)**2, = 0.0004. So var(diff per minibatch) -# = 0.0004 / 4000 = 1e-07, rms(diff per minibatch) = 3.16e-04. So LR would be 3e-04. -# Suggested lr_schedule? -# -# .. 6e-05 is 1/5 of that... diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 92509f4ec..a114dd8f1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -154,15 +154,15 @@ def get_parser(): parser.add_argument( "--lr-begin-steps", type=float, - default=25000, + default=5000, help="Number of steps that affects how rapidly the learning rate initially decreases" ) parser.add_argument( - "--lr-end-epochs", + "--lr-epochs", type=float, default=-1, - help="""Number of epochs that affects how rapidly the learning rate finally decreases; + help="""Number of epochs for purposes of the learning-rate schedule; if -1, will be set the same as --num-epochs """ ) @@ -784,15 +784,15 @@ def run(rank, world_size, args): model.parameters(), lr=params.initial_lr) - # The `epoch` variable in the lambda expression binds to the value below - # in `for epoch in range(params.start_epoch, params.num_epochs):`. But set it to 0 + # The `epoch` variable in the lambda expression picks up to the value below + # in `for epoch in range(params.start_epoch, params.num_epochs):`. Set it to 0 # here to avoid crash in constructor. epoch = 0 - lr_end_epochs = params.lr_end_epochs if params.lr_end_epochs > 0 else params.num_epochs + lr_epochs = params.lr_epochs if params.lr_epochs > 0 else params.num_epochs scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, - lambda step: (((step + params.lr_begin_steps) / params.lr_begin_steps) ** -0.5 * - ((epoch + lr_end_epochs) / lr_end_epochs) ** -2.0)) + lambda step: (((step**2 + params.lr_begin_steps**2) / params.lr_begin_steps**2) ** -0.25 * + ((epoch + lr_epochs) / lr_epochs) ** -0.5)) if checkpoints and "optimizer" in checkpoints: From da50525ca5f1cf5bb655adcac0e8ad5231aa7b5c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 10 Apr 2022 13:25:40 +0800 Subject: [PATCH 187/234] Make lrate rule more symmetric --- .../ASR/pruned_transducer_stateless2/train.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index a114dd8f1..a8aaa4dde 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -148,22 +148,22 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate", + help="The initial learning rate. This value should not need to be changed.", ) parser.add_argument( - "--lr-begin-steps", + "--lr-steps", type=float, default=5000, - help="Number of steps that affects how rapidly the learning rate initially decreases" + help="""Number of steps that affects how rapidly the learning rate decreases. + We suggest not to change this.""" ) parser.add_argument( "--lr-epochs", type=float, - default=-1, - help="""Number of epochs for purposes of the learning-rate schedule; - if -1, will be set the same as --num-epochs + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. """ ) @@ -788,11 +788,10 @@ def run(rank, world_size, args): # in `for epoch in range(params.start_epoch, params.num_epochs):`. Set it to 0 # here to avoid crash in constructor. epoch = 0 - lr_epochs = params.lr_epochs if params.lr_epochs > 0 else params.num_epochs scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, - lambda step: (((step**2 + params.lr_begin_steps**2) / params.lr_begin_steps**2) ** -0.25 * - ((epoch + lr_epochs) / lr_epochs) ** -0.5)) + lambda step: (((step**2 + params.lr_steps**2) / params.lr_steps**2) ** -0.25 * + (((epoch**2 + params.lr_epochs**2) / params.lr_epochs**2) ** -0.25)) if checkpoints and "optimizer" in checkpoints: From 82d58629eaa54b64b32e68cb44d519bb58e530e6 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 10 Apr 2022 13:50:31 +0800 Subject: [PATCH 188/234] Implement 2p version of learning rate schedule. --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index a8aaa4dde..73ba17a71 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -791,7 +791,7 @@ def run(rank, world_size, args): scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lambda step: (((step**2 + params.lr_steps**2) / params.lr_steps**2) ** -0.25 * - (((epoch**2 + params.lr_epochs**2) / params.lr_epochs**2) ** -0.25)) + (((epoch**2 + params.lr_epochs**2) / params.lr_epochs**2) ** -0.25))) if checkpoints and "optimizer" in checkpoints: From d1e4ae788dcddbefd3840c3f5bbc598ec7e225b9 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 10 Apr 2022 15:25:27 +0800 Subject: [PATCH 189/234] Refactor how learning rate is set. --- .../ASR/pruned_transducer_stateless2/optim.py | 151 +++++++++++++++++- .../ASR/pruned_transducer_stateless2/train.py | 43 ++--- icefall/checkpoint.py | 11 +- 3 files changed, 174 insertions(+), 31 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index e47c08657..4f7392d3a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -16,7 +16,7 @@ import random -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import torch from torch import Tensor @@ -141,3 +141,152 @@ class Eve(Optimizer): p.addcdiv_(exp_avg, denom, value=-step_size) return loss + +class LRScheduler(object): + """ + Base-class for learning rate schedulers where the learning-rate depends on both the + batch and the epoch. + """ + def __init__(self, optimizer: Optimizer, verbose: bool = False): + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError('{} is not an Optimizer'.format( + type(optimizer).__name__)) + self.optimizer = optimizer + self.verbose = verbose + + for group in optimizer.param_groups: + group.setdefault('initial_lr', group['lr']) + + self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups] + + self.epoch = 0 + self.batch = 0 + + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return {'base_lrs': self.base_lrs, + 'epoch': self.epoch, + 'batch': self.batch} + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self) -> List[float]: + """ Return last computed learning rate by current scheduler. Will be a list of float. + """ + return self._last_lr + + def get_lr(self): + # Compute list of learning rates from self.epoch and self.batch and + # self.base_lrs; this must be overloaded by the user. + # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] + raise NotImplementedError + + + def step_batch(self, batch: Optional[int] = None) -> None: + # Step the batch index, or just set it. If `batch` is specified, it + # must be the batch index from the start of training, i.e. summed over + # all epochs. + # You can call this in any order; if you don't provide 'batch', it should + # of course be called once per batch. + if batch is not None: + self.batch = batch + else: + self.batch = self.batch + 1 + self._set_lrs() + + def step_epoch(self, epoch: Optional[int] = None): + # Step the epoch index, or just set it. If you provide the 'epoch' arg, + # you should call this at the start of the epoch; if you don't provide the 'epoch' + # arg, you should call it at the end of the epoch. + if epoch is not None: + self.epoch = epoch + else: + self.epoch = self.epoch + 1 + self._set_lrs() + + + def _set_lrs(self): + values = self.get_lr() + assert len(values) == len(self.optimizer.param_groups) + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group['lr'] = lr + self.print_lr(self.verbose, i, lr) + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + + + def print_lr(self, is_verbose, group, lr): + """Display the current learning rate. + """ + if is_verbose: + print(f'Epoch={self.epoch}, batch={self.batch}: adjusting learning rate' + f' of group {group} to {lr:.4e}.') + + +class Eden(LRScheduler): + """ + Eden scheduler. + lr = initial_lr = (((batch**2 + lr_batches**2) / lr_batchses**2) ** -0.25 * + (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) + + E.g. suggest initial-lr = 0.003 (passed to optimizer). + + Args: + optimizer: the optimizer to change the learning rates on + lr_batches: the number of batches after which we start significantly + decreasing the learning rate, suggest 5000. + lr_epochs: the number of epochs after which we start significantly + decreasing the learning rate, suggest 6. + """ + def __init__(self, optimizer: Optimizer, + lr_batches: Union[int, float], + lr_epochs: Union[int, float], + verbose: bool = False): + super(Eden, self).__init__(optimizer, verbose) + self.lr_batches = lr_batches + self.lr_epochs = lr_epochs + + def get_lr(self): + factor = (((self.batch**2 + self.lr_batches**2) / self.lr_batches**2) ** -0.25 * + (((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25)) + return [ x * factor for x in self.base_lrs ] + +def _test_eden(): + m = torch.nn.Linear(100, 100) + optim = Eve(m.parameters(), lr=0.003) + + scheduler = Eden(optim, lr_batches=30, lr_epochs=2, verbose=True) + + for epoch in range(10): + scheduler.step_epoch(epoch) # sets epoch to `epoch` + + for step in range(20): + x = torch.randn(200, 100).detach() + x.requires_grad = True + y = m(x) + dy = torch.randn(200, 100).detach() + f = (y * dy).sum() + f.backward() + + optim.step() + scheduler.step_batch() + optim.zero_grad() + print("last lr = ", scheduler.get_last_lr()) + print("state dict = ", scheduler.state_dict()) + +if __name__ == '__main__': + _test_eden() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 73ba17a71..ddd2e8fb7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -40,7 +40,7 @@ import math import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import k2 import sentencepiece as spm @@ -55,7 +55,7 @@ from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer -from optim import Eve +from optim import Eve, Eden from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -74,6 +74,7 @@ from icefall.utils import ( str2bool, ) +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): parser = argparse.ArgumentParser( @@ -152,7 +153,7 @@ def get_parser(): ) parser.add_argument( - "--lr-steps", + "--lr-batches", type=float, default=5000, help="""Number of steps that affects how rapidly the learning rate decreases. @@ -378,7 +379,7 @@ def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + scheduler: Optional[LRSchedulerType] = None, ) -> Optional[Dict[str, Any]]: """Load checkpoint from file. @@ -443,7 +444,7 @@ def save_checkpoint( params: AttributeDict, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, rank: int = 0, ) -> None: @@ -593,7 +594,7 @@ def train_one_epoch( params: AttributeDict, model: nn.Module, optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler._LRScheduler, + scheduler: LRSchedulerType, sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, @@ -656,17 +657,15 @@ def train_one_epoch( # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. loss.backward() + scheduler.step_batch(params.batch_idx_train) optimizer.step() optimizer.zero_grad() - scheduler.step() if params.print_diagnostics and batch_idx == 5: return - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): + if (params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0): params.cur_batch_idx = batch_idx save_checkpoint_with_global_batch_idx( out_dir=params.exp_dir, @@ -686,13 +685,17 @@ def train_one_epoch( ) if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] logging.info( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" ) if tb_writer is not None: + tb_writer.add_scalar("train/learning_rate", cur_lr) + loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) @@ -784,14 +787,7 @@ def run(rank, world_size, args): model.parameters(), lr=params.initial_lr) - # The `epoch` variable in the lambda expression picks up to the value below - # in `for epoch in range(params.start_epoch, params.num_epochs):`. Set it to 0 - # here to avoid crash in constructor. - epoch = 0 - scheduler = torch.optim.lr_scheduler.LambdaLR( - optimizer, - lambda step: (((step**2 + params.lr_steps**2) / params.lr_steps**2) ** -0.25 * - (((epoch**2 + params.lr_epochs**2) / params.lr_epochs**2) ** -0.25))) + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) if checkpoints and "optimizer" in checkpoints: @@ -854,19 +850,14 @@ def run(rank, world_size, args): ) for epoch in range(params.start_epoch, params.num_epochs): + scheduler.step_epoch(epoch) fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) cur_lr = scheduler.get_last_lr()[0] if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - if rank == 0: - logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) - params.cur_epoch = epoch train_one_epoch( diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 251456c95..c0d4b3968 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -28,15 +28,18 @@ from lhotse.dataset.sampling.base import CutSampler from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler +# use duck typing for LRScheduler since we have different possibilities, see +# our class LRScheduler. +LRSchedulerType = object + def save_checkpoint( filename: Path, model: Union[nn.Module, DDP], params: Optional[Dict[str, Any]] = None, optimizer: Optional[Optimizer] = None, - scheduler: Optional[_LRScheduler] = None, + scheduler: Optional[LRSchedulerType] = None, scaler: Optional[GradScaler] = None, sampler: Optional[CutSampler] = None, rank: int = 0, @@ -89,7 +92,7 @@ def load_checkpoint( filename: Path, model: nn.Module, optimizer: Optional[Optimizer] = None, - scheduler: Optional[_LRScheduler] = None, + scheduler: Optional[LRSchedulerType] = None, scaler: Optional[GradScaler] = None, sampler: Optional[CutSampler] = None, strict: bool = False, @@ -167,7 +170,7 @@ def save_checkpoint_with_global_batch_idx( model: Union[nn.Module, DDP], params: Optional[Dict[str, Any]] = None, optimizer: Optional[Optimizer] = None, - scheduler: Optional[_LRScheduler] = None, + scheduler: Optional[LRSchedulerType] = None, scaler: Optional[GradScaler] = None, sampler: Optional[CutSampler] = None, rank: int = 0, From 962cf868c960125170802294c79338adec391ffa Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 10 Apr 2022 15:31:46 +0800 Subject: [PATCH 190/234] Fix import --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index ddd2e8fb7..62dc825b6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -45,6 +45,7 @@ from typing import Any, Dict, Optional, Tuple, Union import k2 import sentencepiece as spm import torch +import optim # from . import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule From 8129470586be92cda3708cc49945fbfbd71ca176 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sun, 10 Apr 2022 20:24:20 +0800 Subject: [PATCH 191/234] first upload the conv_emformer_transducer recipe, integrating convolution module into emformer layers. --- .../asr_datamodule.py | 1 + .../conv_emformer_transducer/beam_search.py | 1 + .../ASR/conv_emformer_transducer/decode.py | 549 +++++++ .../ASR/conv_emformer_transducer/decoder.py | 1 + .../ASR/conv_emformer_transducer/emformer.py | 1445 +++++++++++++++++ .../encoder_interface.py | 1 + .../ASR/conv_emformer_transducer/joiner.py | 1 + .../ASR/conv_emformer_transducer/model.py | 1 + .../ASR/conv_emformer_transducer/noam.py | 104 ++ .../conv_emformer_transducer/subsampling.py | 1 + .../conv_emformer_transducer/test_emformer.py | 359 ++++ .../ASR/conv_emformer_transducer/train.py | 1006 ++++++++++++ 12 files changed, 3470 insertions(+) create mode 120000 egs/librispeech/ASR/conv_emformer_transducer/asr_datamodule.py create mode 120000 egs/librispeech/ASR/conv_emformer_transducer/beam_search.py create mode 100755 egs/librispeech/ASR/conv_emformer_transducer/decode.py create mode 120000 egs/librispeech/ASR/conv_emformer_transducer/decoder.py create mode 100644 egs/librispeech/ASR/conv_emformer_transducer/emformer.py create mode 120000 egs/librispeech/ASR/conv_emformer_transducer/encoder_interface.py create mode 120000 egs/librispeech/ASR/conv_emformer_transducer/joiner.py create mode 120000 egs/librispeech/ASR/conv_emformer_transducer/model.py create mode 100644 egs/librispeech/ASR/conv_emformer_transducer/noam.py create mode 120000 egs/librispeech/ASR/conv_emformer_transducer/subsampling.py create mode 100644 egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py create mode 100755 egs/librispeech/ASR/conv_emformer_transducer/train.py diff --git a/egs/librispeech/ASR/conv_emformer_transducer/asr_datamodule.py b/egs/librispeech/ASR/conv_emformer_transducer/asr_datamodule.py new file mode 120000 index 000000000..b4e5427e0 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conv_emformer_transducer/beam_search.py b/egs/librispeech/ASR/conv_emformer_transducer/beam_search.py new file mode 120000 index 000000000..227d2247c --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conv_emformer_transducer/decode.py b/egs/librispeech/ASR/conv_emformer_transducer/decode.py new file mode 100755 index 000000000..c40b01dfa --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer/decode.py @@ -0,0 +1,549 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./transducer_emformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./transducer_emformer/exp \ + --max-duration 100 \ + --decoding-method greedy_search + +(2) beam search +./transducer_emformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./transducer_emformer/exp \ + --max-duration 100 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./transducer_emformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./transducer_emformer/exp \ + --max-duration 100 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search +./transducer_emformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./transducer_emformer/exp \ + --max-duration 1500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +""" + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--avg-last-n", + type=int, + default=0, + help="""If positive, --epoch and --avg are ignored and it + will use the last n checkpoints exp_dir/checkpoint-xxx.pt + where xxx is the number of processed batches while + saving that checkpoint. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer_emformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = model.device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif params.decoding_method == "fast_beam_search": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps + } + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 100 + else: + log_interval = 2 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if params.avg_last_n > 0: + filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.to(device) + model.eval() + model.device = device + + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer/decoder.py b/egs/librispeech/ASR/conv_emformer_transducer/decoder.py new file mode 120000 index 000000000..0d5f10dc0 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py new file mode 100644 index 000000000..49e59bd00 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py @@ -0,0 +1,1445 @@ +# Copyright 2022 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# It is modified based on +# https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py. + +import math +import warnings +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from subsampling import Conv2dSubsampling, VggSubsampling + +from icefall.utils import make_pad_mask + + +def _gen_attention_mask_block( + col_widths: List[int], + col_mask: List[bool], + num_rows: int, + device: torch.device, +) -> torch.Tensor: + assert len(col_widths) == len( + col_mask + ), "Length of col_widths must match that of col_mask" + + mask_block = [ + torch.ones(num_rows, col_width, device=device) + if is_ones_col + else torch.zeros(num_rows, col_width, device=device) + for col_width, is_ones_col in zip(col_widths, col_mask) + ] + return torch.cat(mask_block, dim=1) + + +class EmformerAttention(nn.Module): + r"""Emformer layer attention module. + + Args: + embed_dim (int): + Embedding dimension. + nhead (int): + Number of attention heads in each Emformer layer. + dropout (float, optional): + Dropout probability. (Default: 0.0) + tanh_on_mem (bool, optional): + If ``True``, applies tanh to memory elements. (Default: ``False``) + negative_inf (float, optional): + Value to use for negative infinity in attention weights. (Default: -1e8) + """ + + def __init__( + self, + embed_dim: int, + nhead: int, + dropout: float = 0.0, + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + ): + super().__init__() + + if embed_dim % nhead != 0: + raise ValueError( + f"embed_dim ({embed_dim}) is not a multiple of" + f"nhead ({nhead})." + ) + + self.embed_dim = embed_dim + self.nhead = nhead + self.dropout = dropout + self.tanh_on_mem = tanh_on_mem + self.negative_inf = negative_inf + + self.scaling = (self.embed_dim // self.nhead) ** -0.5 + + self.emb_to_key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=True) + self.emb_to_query = nn.Linear(embed_dim, embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + def _reset_parameters(self) -> None: + nn.init.xavier_uniform_(self.emb_to_key_value.weight) + nn.init.constant_(self.emb_to_key_value.bias, 0.0) + + nn.init.xavier_uniform_(self.emb_to_query.weight) + nn.init.constant_(self.emb_to_query.bias, 0.0) + + nn.init.xavier_uniform_(self.out_proj.weight) + nn.init.constant_(self.out_proj.bias, 0.0) + + def _gen_attention_probs( + self, + attention_weights: torch.Tensor, + attention_mask: torch.Tensor, + padding_mask: Optional[torch.Tensor], + ) -> torch.Tensor: + """Given the entire attention weights, mask out unecessary connections + and optionally with padding positions, to obtain underlying chunk-wise + attention probabilities. + + B: batch size; + Q: length of query; + KV: length of key and value. + + Args: + attention_weights (torch.Tensor): + Attention weights computed on the entire concatenated tensor + with shape (B * nhead, Q, KV). + attention_mask (torch.Tensor): + Mask tensor where chunk-wise connections are filled with `False`, + and other unnecessary connections are filled with `True`, + with shape (Q, KV). + padding_mask (torch.Tensor, optional): + Mask tensor where the padding positions are fill with `True`, + and other positions are filled with `False`, with shapa `(B, KV)`. + + Returns: + A tensor of shape (B * nhead, Q, KV). + """ + attention_weights_float = attention_weights.float() + attention_weights_float = attention_weights_float.masked_fill( + attention_mask.unsqueeze(0), self.negative_inf + ) + if padding_mask is not None: + Q = attention_weights.size(1) + B = attention_weights.size(0) // self.nhead + attention_weights_float = attention_weights_float.view( + B, self.nhead, Q, -1 + ) + attention_weights_float = attention_weights_float.masked_fill( + padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + self.negative_inf, + ) + attention_weights_float = attention_weights_float.view( + B * self.nhead, Q, -1 + ) + + attention_probs = nn.functional.softmax( + attention_weights_float, dim=-1 + ).type_as(attention_weights) + attention_probs = nn.functional.dropout( + attention_probs, p=float(self.dropout), training=self.training + ) + return attention_probs + + def _forward_impl( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + summary: torch.Tensor, + memory: torch.Tensor, + attention_mask: torch.Tensor, + left_context_key: Optional[torch.Tensor] = None, + left_context_val: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Underlying chunk-wise attention implementation. + + L: length of left_context; + S: length of summary; + M: length of memory; + Q: length of attention query; + KV: length of attention key and value. + + 1) Concat right_context, utterance, summary, + and compute query tensor with length Q = R + U + S. + 2) Concat memory, right_context, utterance, + and compute key, value tensors with length KV = M + R + U; + optionally with left_context_key and left_context_val (inference mode), + then KV = M + R + L + U. + 3) Compute entire attention scores with query, key, and value, + then apply attention_mask to get underlying chunk-wise attention scores. + + Args: + utterance (torch.Tensor): + Utterance frames, with shape (U, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing + number of valid frames for i-th batch element in utterance. + right_context (torch.Tensor): + Right context frames, with shape (R, B, D). + summary (torch.Tensor): + Summary elements, with shape (S, B, D). + memory (torch.Tensor): + Memory elements, with shape (M, B, D). + attention_mask (torch.Tensor): + Attention mask for underlying attention, with shape (Q, KV). + left_context_key (torch,Tensor, optional): + Cached attention key of left context from preceding computation, + with shape (L, B, D). + left_context_val (torch.Tensor, optional): + Cached attention value of left context from preceding computation, + with shape (L, B, D). + + Returns: + A tuple containing 4 tensors: + - output of right context and utterance, with shape (R + U, B, D). + - memory output, with shape (S, B, D). + - attention key, with shape (KV, B, D). + - attention value, with shape (KV, B, D). + """ + B = utterance.size(1) + + # Compute query with [right context, utterance, summary]. + query = self.emb_to_query( + torch.cat([right_context, utterance, summary]) + ) + # Compute key and value with [mems, right context, utterance]. + key, value = self.emb_to_key_value( + torch.cat([memory, right_context, utterance]) + ).chunk(chunks=2, dim=2) + + if left_context_key is not None and left_context_val is not None: + # This is for inference mode. Now compute key and value with + # [mems, right context, left context, uttrance] + M = memory.size(0) + R = right_context.size(0) + right_context_end_idx = M + R + key = torch.cat( + [ + key[:right_context_end_idx], + left_context_key, + key[right_context_end_idx:], + ] + ) + value = torch.cat( + [ + value[:right_context_end_idx], + left_context_val, + value[right_context_end_idx:], + ] + ) + + # Compute attention weights from query, key, and value. + reshaped_query, reshaped_key, reshaped_value = [ + tensor.contiguous() + .view(-1, B * self.nhead, self.embed_dim // self.nhead) + .transpose(0, 1) + for tensor in [query, key, value] + ] + attention_weights = torch.bmm( + reshaped_query * self.scaling, reshaped_key.transpose(1, 2) + ) + + # Compute padding mask + if B == 1: + padding_mask = None + else: + KV = key.size(0) + U = utterance.size(0) + padding_mask = make_pad_mask(KV - U + lengths) + + # Compute attention probabilities. + attention_probs = self._gen_attention_probs( + attention_weights, attention_mask, padding_mask + ) + + # Compute attention. + attention = torch.bmm(attention_probs, reshaped_value) + Q = query.size(0) + assert attention.shape == ( + B * self.nhead, + Q, + self.embed_dim // self.nhead, + ) + attention = ( + attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + ) + + # Apply output projection. + outputs = self.out_proj(attention) + + S = summary.size(0) + summary_start_idx = Q - S + output_right_context_utterance = outputs[:summary_start_idx] + output_memory = outputs[summary_start_idx:] + if self.tanh_on_mem: + output_memory = torch.tanh(output_memory) + else: + output_memory = torch.clamp(output_memory, min=-10, max=10) + + return output_right_context_utterance, output_memory, key, value + + def forward( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + summary: torch.Tensor, + memory: torch.Tensor, + attention_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # TODO: Modify docs. + """Forward pass for training. + + B: batch size; + D: embedding dimension; + R: length of right_context; + U: length of utterance; + S: length of summary; + M: length of memory. + + Args: + utterance (torch.Tensor): + Utterance frames, with shape (U, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing + number of valid frames for i-th batch element in utterance. + right_context (torch.Tensor): + Right context frames, with shape (R, B, D). + summary (torch.Tensor): + Summary elements, with shape (S, B, D). + memory (torch.Tensor): + Memory elements, with shape (M, B, D). + attention_mask (torch.Tensor): + Attention mask for underlying chunk-wise attention, + with shape (Q, KV), where Q = R + U + S, KV = M + R + U. + + Returns: + A tuple containing 2 tensors: + - output of right context and utterance, with shape (R + U, B, D). + - memory output, with shape (M, B, D), where M = S - 1 or M = 0. + """ + ( + output_right_context_utterance, + output_memory, + _, + _, + ) = self._forward_impl( + utterance, lengths, right_context, summary, memory, attention_mask + ) + return output_right_context_utterance, output_memory[:-1] + + @torch.jit.export + def infer( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + summary: torch.Tensor, + memory: torch.Tensor, + left_context_key: torch.Tensor, + left_context_val: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward pass for inference. + + B: batch size; + D: embedding dimension; + R: length of right_context; + U: length of utterance; + L: length of left_context; + S: length of summary; + M: length of memory; + + Args: + utterance (torch.Tensor): + Utterance frames, with shape (U, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing + number of valid frames for i-th batch element in utterance. + right_context (torch.Tensor): + Right context frames, with shape (R, B, D). + summary (torch.Tensor): + Summary element, with shape (1, B, D), or empty. + memory (torch.Tensor): + Memory elements, with shape (M, B, D). + left_context_key (torch,Tensor): + Cached attention key of left context from preceding computation, + with shape (L, B, D). + left_context_val (torch.Tensor): + Cached attention value of left context from preceding computation, + with shape (L, B, D). + + Returns: + A tuple containing 4 tensors: + - output of right context and utterance, with shape (R + U, B, D). + - memory output, with shape (1, B, D) or (0, B, D). + - attention key of left context and utterance, which would be cached + for next computation, with shape (L + U, B, D). + - attention value of left context and utterance, which would be + cached for next computation, with shape (L + U, B, D). + """ + # query: [right context, utterance, summary] + Q = right_context.size(0) + utterance.size(0) + summary.size(0) + # key, value: [memory, right context, left context, uttrance] + KV = ( + memory.size(0) + + right_context.size(0) + + left_context_key.size(0) + + utterance.size(0) + ) + attention_mask = torch.zeros(Q, KV).to( + dtype=torch.bool, device=utterance.device + ) + # Disallow attention bettween the summary vector with the memory bank + attention_mask[-1, : memory.size(0)] = True + ( + output_right_context_utterance, + output_memory, + key, + value, + ) = self._forward_impl( + utterance, + lengths, + right_context, + summary, + memory, + attention_mask, + left_context_key=left_context_key, + left_context_val=left_context_val, + ) + right_context_end_idx = memory.size(0) + right_context.size(0) + return ( + output_right_context_utterance, + output_memory, + key[right_context_end_idx:], + value[right_context_end_idx:], + ) + + +class EmformerLayer(nn.Module): + """Emformer layer that constitutes Emformer. + + Args: + d_model (int): + Input dimension. + nhead (int): + Number of attention heads. + dim_feedforward (int): + Hidden layer dimension of feedforward network. + chunk_length (int): + Length of each input segment. + dropout (float, optional): + Dropout probability. (Default: 0.0) + cnn_module_kernel (int): + Kernel size of convolution module. + left_context_length (int, optional): + Length of left context. (Default: 0) + max_memory_size (int, optional): + Maximum number of memory elements to use. (Default: 0) + tanh_on_mem (bool, optional): + If ``True``, applies tanh to memory elements. (Default: ``False``) + negative_inf (float, optional): + Value to use for negative infinity in attention weights. (Default: -1e8) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int, + chunk_length: int, + dropout: float = 0.0, + cnn_module_kernel: int = 3, + left_context_length: int = 0, + max_memory_size: int = 0, + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + ): + super().__init__() + + self.attention = EmformerAttention( + embed_dim=d_model, + nhead=nhead, + dropout=0.0, + tanh_on_mem=tanh_on_mem, + negative_inf=negative_inf, + ) + self.summary_op = nn.AvgPool1d( + kernel_size=chunk_length, stride=chunk_length, ceil_mode=True + ) + + self.feed_forward_macaron = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.feed_forward = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + + self.norm_ff_macaron = nn.LayerNorm(d_model) + self.norm_ff = nn.LayerNorm(d_model) + self.norm_mha = nn.LayerNorm(d_model) + self.norm_conv = nn.LayerNorm(d_model) + self.norm_final = nn.LayerNorm(d_model) + + self.dropout = nn.Dropout(dropout) + + self.ff_scale = 0.5 + self.left_context_length = left_context_length + self.chunk_length = chunk_length + self.max_memory_size = max_memory_size + self.d_model = d_model + self.use_memory = max_memory_size > 0 + + def _init_state( + self, batch_size: int, device: Optional[torch.device] + ) -> List[torch.Tensor]: + """Initialize states with zeros.""" + empty_memory = torch.zeros( + self.max_memory_size, batch_size, self.d_model, device=device + ) + left_context_key = torch.zeros( + self.left_context_length, batch_size, self.d_model, device=device + ) + left_context_val = torch.zeros( + self.left_context_length, batch_size, self.d_model, device=device + ) + past_length = torch.zeros( + 1, batch_size, dtype=torch.int32, device=device + ) + return [empty_memory, left_context_key, left_context_val, past_length] + + def _unpack_state( + self, state: List[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Unpack cached states including: + 1) output memory from previous chunks in the lower layer; + 2) attention key and value of left context from proceeding chunk's + computation. + """ + past_length = state[3][0][0].item() + past_left_context_length = min(self.left_context_length, past_length) + past_memory_length = min( + self.max_memory_size, math.ceil(past_length / self.chunk_length) + ) + memory_start_idx = self.max_memory_size - past_memory_length + pre_memory = state[0][memory_start_idx:] + left_context_start_idx = ( + self.left_context_length - past_left_context_length + ) + left_context_key = state[1][left_context_start_idx:] + left_context_val = state[2][left_context_start_idx:] + return pre_memory, left_context_key, left_context_val + + def _pack_state( + self, + next_key: torch.Tensor, + next_val: torch.Tensor, + update_length: int, + memory: torch.Tensor, + state: List[torch.Tensor], + ) -> List[torch.Tensor]: + """Pack updated states including: + 1) output memory of current chunk in the lower layer; + 2) attention key and value in current chunk's computation, which would + be resued in next chunk's computation. + 3) length of current chunk. + """ + new_memory = torch.cat([state[0], memory]) + new_key = torch.cat([state[1], next_key]) + new_val = torch.cat([state[2], next_val]) + memory_start_idx = new_memory.size(0) - self.max_memory_size + state[0] = new_memory[memory_start_idx:] + key_start_idx = new_key.size(0) - self.left_context_length + state[1] = new_key[key_start_idx:] + val_start_idx = new_val.size(0) - self.left_context_length + state[2] = new_val[val_start_idx:] + state[3] = state[3] + update_length + return state + + def _apply_macaron_feed_foward_module( + self, right_context_utterance: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply macaron style feed forward module.""" + residual = right_context_utterance + right_context_utterance = self.norm_ff_macaron(right_context_utterance) + right_context_utterance = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(right_context_utterance) + ) + return right_context_utterance + + def _apply_conv_module( + self, + right_context_utterance: torch.Tensor, + right_context_end_idx: int, + ) -> torch.Tensor: + """Apply convolution module on utterance.""" + utterance = right_context_utterance[right_context_end_idx:] + right_context = right_context_utterance[:right_context_end_idx] + + residual = utterance + utterance = self.norm_conv(utterance) + utterance = residual + self.dropout(self.conv_module(utterance)) + right_context_utterance = torch.cat([right_context, utterance]) + return right_context_utterance + + def _apply_feed_forward_module( + self, + right_context_utterance: torch.Tensor, + ) -> torch.Tensor: + """Apply feed forward module.""" + residual = right_context_utterance + right_context_utterance = self.norm_ff(right_context_utterance) + right_context_utterance = residual + self.ff_scale * self.dropout( + self.feed_forward(right_context_utterance) + ) + return right_context_utterance + + def _apply_attention_module_forward( + self, + right_context_utterance: torch.Tensor, + right_context_end_idx: int, + lengths: torch.Tensor, + memory: torch.Tensor, + attention_mask: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply attention module in non-infer mode.""" + if attention_mask is None: + raise ValueError( + "attention_mask must be not None in non-infer mode. " + ) + + residual = right_context_utterance + right_context_utterance = self.norm_mha(right_context_utterance) + utterance = right_context_utterance[right_context_end_idx:] + right_context = right_context_utterance[:right_context_end_idx] + + if self.use_memory: + summary = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + ) + else: + summary = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) + output_right_context_utterance, output_memory = self.attention( + utterance=utterance, + lengths=lengths, + right_context=right_context, + summary=summary, + memory=memory, + attention_mask=attention_mask, + ) + right_context_utterance = residual + self.dropout( + output_right_context_utterance + ) + + return right_context_utterance, output_memory + + def _apply_attention_module_infer( + self, + right_context_utterance: torch.Tensor, + right_context_end_idx: int, + lengths: torch.Tensor, + memory: torch.Tensor, + state: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + """Apply attention in infer mode. + 1) Unpack cached states including: + - memory from previous chunks in the lower layer; + - attention key and value of left context from proceeding + chunk's compuation; + 2) Apply attention computation; + 3) Pack updated states including: + - output memory of current chunk in the lower layer; + - attention key and value in current chunk's computation, which would + be resued in next chunk's computation. + - length of current chunk. + """ + residual = right_context_utterance + right_context_utterance = self.norm_mha(right_context_utterance) + utterance = right_context_utterance[right_context_end_idx:] + right_context = right_context_utterance[:right_context_end_idx] + + if state is None: + state = self._init_state(utterance.size(1), device=utterance.device) + pre_memory, left_context_key, left_context_val = self._unpack_state( + state + ) + if self.use_memory: + summary = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + ) + summary = summary[:1] + else: + summary = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) + ( + output_right_context_utterance, + output_memory, + next_key, + next_val, + ) = self.attention.infer( + utterance=utterance, + lengths=lengths, + right_context=right_context, + summary=summary, + memory=pre_memory, + left_context_key=left_context_key, + left_context_val=left_context_val, + ) + right_context_utterance = residual + self.dropout( + output_right_context_utterance + ) + state = self._pack_state( + next_key, next_val, utterance.size(0), memory, state + ) + return right_context_utterance, output_memory, state + + def forward( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + memory: torch.Tensor, + attention_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + r"""Forward pass for training. + 1) Apply layer normalization on input utterance and right context + before attention; + 2) Apply attention module, compute updated utterance, right context, + and memory; + 3) Apply feed forward module and layer normalization on output utterance + and right context. + + B: batch size; + D: embedding dimension; + R: length of right_context; + U: length of utterance; + M: length of memory. + + Args: + utterance (torch.Tensor): + Utterance frames, with shape (U, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing + number of valid frames for i-th batch element in utterance. + right_context (torch.Tensor): + Right context frames, with shape (R, B, D). + memory (torch.Tensor): + Memory elements, with shape (M, B, D). + attention_mask (torch.Tensor): + Attention mask for underlying attention module, + with shape (Q, KV), where Q = R + U + S, KV = M + R + U. + + Returns: + A tuple containing 3 tensors: + - output utterance, with shape (U, B, D). + - output right context, with shape (R, B, D). + - output memory, with shape (M, B, D). + """ + right_context_utterance = torch.cat([right_context, utterance]) + right_context_end_idx = right_context.size(0) + + right_context_utterance = self._apply_macaron_feed_foward_module( + right_context_utterance + ) + + ( + right_context_utterance, + output_memory, + ) = self._apply_attention_module_forward( + right_context_utterance, + right_context_end_idx, + lengths, + memory, + attention_mask, + ) + + right_context_utterance = self._apply_conv_module( + right_context_utterance, right_context_end_idx + ) + + right_context_utterance = self._apply_feed_forward_module( + right_context_utterance + ) + + right_context_utterance = self.norm_final(right_context_utterance) + + output_utterance = right_context_utterance[right_context_end_idx:] + output_right_context = right_context_utterance[:right_context_end_idx] + return output_utterance, output_right_context, output_memory + + @torch.jit.export + def infer( + self, + utterance: torch.Tensor, + lengths: torch.Tensor, + right_context: torch.Tensor, + memory: torch.Tensor, + state: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]: + """Forward pass for inference. + + 1) Apply layer normalization on input utterance and right context + before attention; + 2) Apply attention module with cached state, compute updated utterance, + right context, and memory, and update state; + 3) Apply feed forward module and layer normalization on output + utterance and right context. + + B: batch size; + D: embedding dimension; + R: length of right_context; + U: length of utterance; + M: length of memory. + + Args: + utterance (torch.Tensor): + Utterance frames, with shape (U, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing + number of valid frames for i-th batch element in utterance. + right_context (torch.Tensor): + Right context frames, with shape (R, B, D). + memory (torch.Tensor): + Memory elements, with shape (M, B, D). + state (List[torch.Tensor], optional): + List of tensors representing layer internal state generated in + preceding computation. (default=None) + + Returns: + (Tensor, Tensor, List[torch.Tensor], Tensor): + - output utterance, with shape (U, B, D); + - output right_context, with shape (R, B, D); + - output memory, with shape (1, B, D) or (0, B, D). + - output state. + """ + right_context_utterance = torch.cat([right_context, utterance]) + right_context_end_idx = right_context.size(0) + + right_context_utterance = self._apply_macaron_feed_foward_module( + right_context_utterance + ) + + ( + right_context_utterance, + output_memory, + output_state, + ) = self._apply_attention_module_infer( + right_context_utterance, + right_context_end_idx, + lengths, + memory, + state, + ) + + right_context_utterance = self._apply_conv_module( + right_context_utterance, right_context_end_idx + ) + + right_context_utterance = self._apply_feed_forward_module( + right_context_utterance + ) + + right_context_utterance = self.norm_final(right_context_utterance) + + output_utterance = right_context_utterance[right_context_end_idx:] + output_right_context = right_context_utterance[:right_context_end_idx] + return ( + output_utterance, + output_right_context, + output_memory, + output_state, + ) + + +class EmformerEncoder(nn.Module): + """Implements the Emformer architecture introduced in + *Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency + Streaming Speech Recognition* + [:footcite:`shi2021emformer`]. + + Args: + d_model (int): + Input dimension. + nhead (int): + Number of attention heads in each emformer layer. + dim_feedforward (int): + Hidden layer dimension of each emformer layer's feedforward network. + num_encoder_layers (int): + Number of emformer layers to instantiate. + chunk_length (int): + Length of each input segment. + dropout (float, optional): + Dropout probability. (default: 0.0) + left_context_length (int, optional): + Length of left context. (default: 0) + right_context_length (int, optional): + Length of right context. (default: 0) + max_memory_size (int, optional): + Maximum number of memory elements to use. (default: 0) + tanh_on_mem (bool, optional): + If ``true``, applies tanh to memory elements. (default: ``false``) + negative_inf (float, optional): + Value to use for negative infinity in attention weights. (default: -1e8) + """ + + def __init__( + self, + chunk_length: int, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + dropout: float = 0.1, + cnn_module_kernel: int = 3, + left_context_length: int = 0, + right_context_length: int = 0, + max_memory_size: int = 0, + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + ): + super().__init__() + + self.use_memory = max_memory_size > 0 + self.init_memory_op = nn.AvgPool1d( + kernel_size=chunk_length, + stride=chunk_length, + ceil_mode=True, + ) + + self.emformer_layers = nn.ModuleList( + [ + EmformerLayer( + d_model, + nhead, + dim_feedforward, + chunk_length, + dropout=dropout, + cnn_module_kernel=cnn_module_kernel, + left_context_length=left_context_length, + max_memory_size=max_memory_size, + tanh_on_mem=tanh_on_mem, + negative_inf=negative_inf, + ) + for layer_idx in range(num_encoder_layers) + ] + ) + + self.left_context_length = left_context_length + self.right_context_length = right_context_length + self.chunk_length = chunk_length + self.max_memory_size = max_memory_size + + def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: + """Hard copy each chunk's right context and concat them.""" + T = x.shape[0] + num_segs = math.ceil( + (T - self.right_context_length) / self.chunk_length + ) + right_context_blocks = [] + for seg_idx in range(num_segs - 1): + start = (seg_idx + 1) * self.chunk_length + end = start + self.right_context_length + right_context_blocks.append(x[start:end]) + last_right_context_start_idx = T - self.right_context_length + right_context_blocks.append(x[last_right_context_start_idx:]) + return torch.cat(right_context_blocks) + + def _gen_attention_mask_col_widths( + self, chunk_idx: int, U: int + ) -> List[int]: + """Calculate column widths (key, value) in attention mask for the + chunk_idx chunk.""" + num_chunks = math.ceil(U / self.chunk_length) + rc = self.right_context_length + lc = self.left_context_length + rc_start = chunk_idx * rc + rc_end = rc_start + rc + chunk_start = max(chunk_idx * self.chunk_length - lc, 0) + chunk_end = min((chunk_idx + 1) * self.chunk_length, U) + R = rc * num_chunks + + if self.use_memory: + m_start = max(chunk_idx - self.max_memory_size, 0) + M = num_chunks - 1 + col_widths = [ + m_start, # before memory + chunk_idx - m_start, # memory + M - chunk_idx, # after memory + rc_start, # before right context + rc, # right context + R - rc_end, # after right context + chunk_start, # before chunk + chunk_end - chunk_start, # chunk + U - chunk_end, # after chunk + ] + else: + col_widths = [ + rc_start, # before right context + rc, # right context + R - rc_end, # after right context + chunk_start, # before chunk + chunk_end - chunk_start, # chunk + U - chunk_end, # after chunk + ] + + return col_widths + + def _gen_attention_mask(self, utterance: torch.Tensor) -> torch.Tensor: + """Generate attention mask for underlying chunk-wise attention + computation, where chunk-wise connections are filled with `False`, + and other unnecessary connections beyond chunk are filled with `True`. + + R: length of right_context; + U: length of utterance; + S: length of summary; + M: length of memory; + Q: length of attention query; + KV: length of attention key and value. + + The shape of attention mask is (Q, KV). + If self.use_memory is `True`: + query = [right_context, utterance, summary]; + key, value = [memory, right_context, utterance]; + Q = R + U + S, KV = M + R + U. + Otherwise: + query = [right_context, utterance] + key, value = [right_context, utterance] + Q = R + U, KV = R + U. + """ + U = utterance.size(0) + num_chunks = math.ceil(U / self.chunk_length) + + right_context_mask = [] + utterance_mask = [] + summary_mask = [] + + if self.use_memory: + num_cols = 9 + # right context and utterance both attend to memory, right context, + # utterance + right_context_utterance_cols_mask = [ + idx in [1, 4, 7] for idx in range(num_cols) + ] + # summary attends to right context, utterance + summary_cols_mask = [idx in [4, 7] for idx in range(num_cols)] + masks_to_concat = [right_context_mask, utterance_mask, summary_mask] + else: + num_cols = 6 + # right context and utterance both attend to right context and + # utterance + right_context_utterance_cols_mask = [ + idx in [1, 4] for idx in range(num_cols) + ] + summary_cols_mask = None + masks_to_concat = [right_context_mask, utterance_mask] + + for chunk_idx in range(num_chunks): + col_widths = self._gen_attention_mask_col_widths(chunk_idx, U) + + right_context_mask_block = _gen_attention_mask_block( + col_widths, + right_context_utterance_cols_mask, + self.right_context_length, + utterance.device, + ) + right_context_mask.append(right_context_mask_block) + + utterance_mask_block = _gen_attention_mask_block( + col_widths, + right_context_utterance_cols_mask, + min( + self.chunk_length, + U - chunk_idx * self.chunk_length, + ), + utterance.device, + ) + utterance_mask.append(utterance_mask_block) + + if summary_cols_mask is not None: + summary_mask_block = _gen_attention_mask_block( + col_widths, summary_cols_mask, 1, utterance.device + ) + summary_mask.append(summary_mask_block) + + attention_mask = ( + 1 - torch.cat([torch.cat(mask) for mask in masks_to_concat]) + ).to(torch.bool) + return attention_mask + + def forward( + self, x: torch.Tensor, lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass for training and non-streaming inference. + + B: batch size; + D: input dimension; + U: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (U + right_context_length, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x, which contains the + right_context at the end. + + Returns: + A tuple of 2 tensors: + - output utterance frames, with shape (U, B, D). + - output_lengths, with shape (B,), without containing the + right_context at the end. + """ + right_context = self._gen_right_context(x) + utterance = x[: x.size(0) - self.right_context_length] + output_lengths = torch.clamp(lengths - self.right_context_length, min=0) + attention_mask = self._gen_attention_mask(utterance) + memory = ( + self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1 + ] + if self.use_memory + else torch.empty(0).to(dtype=x.dtype, device=x.device) + ) + output = utterance + for layer in self.emformer_layers: + output, right_context, memory = layer( + output, output_lengths, right_context, memory, attention_mask + ) + + return output, output_lengths + + @torch.jit.export + def infer( + self, + x: torch.Tensor, + lengths: torch.Tensor, + states: Optional[List[List[torch.Tensor]]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]: + """Forward pass for streaming inference. + + B: batch size; + D: input dimension; + U: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (U + right_context_length, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x, which contains the + right_context at the end. + states (List[List[torch.Tensor]], optional): + Cached states from proceeding chunk's computation, where each + element (List[torch.Tensor]) corresponding to each emformer layer. + (default: None) + + Returns: + (Tensor, Tensor, List[List[torch.Tensor]]): + - output utterance frames, with shape (U, B, D). + - output lengths, with shape (B,), without containing the + right_context at the end. + - updated states from current chunk's computation. + """ + assert x.size(0) == self.chunk_length + self.right_context_length, ( + "Per configured chunk_length and right_context_length, " + f"expected size of {self.chunk_length + self.right_context_length} " + f"for dimension 1 of x, but got {x.size(1)}." + ) + right_context_start_idx = x.size(0) - self.right_context_length + right_context = x[right_context_start_idx:] + utterance = x[:right_context_start_idx] + output_lengths = torch.clamp(lengths - self.right_context_length, min=0) + memory = ( + self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) + if self.use_memory + else torch.empty(0).to(dtype=x.dtype, device=x.device) + ) + output = utterance + output_states: List[List[torch.Tensor]] = [] + for layer_idx, layer in enumerate(self.emformer_layers): + output, right_context, memory, output_state = layer.infer( + output, + output_lengths, + right_context, + memory, + None if states is None else states[layer_idx], + ) + output_states.append(output_state) + + return output, output_lengths, output_states + + +class Emformer(EncoderInterface): + def __init__( + self, + num_features: int, + output_dim: int, + chunk_length: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + dropout: float = 0.1, + cnn_module_kernel: int = 3, + vgg_frontend: bool = False, + left_context_length: int = 0, + right_context_length: int = 0, + max_memory_size: int = 0, + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + ): + super().__init__() + + self.subsampling_factor = subsampling_factor + self.right_context_length = right_context_length + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + if chunk_length % 4 != 0: + raise NotImplementedError("chunk_length must be a mutiple of 4.") + if left_context_length != 0 and left_context_length % 4 != 0: + raise NotImplementedError( + "left_context_length must be 0 or a mutiple of 4." + ) + if right_context_length != 0 and right_context_length % 4 != 0: + raise NotImplementedError( + "right_context_length must be 0 or a mutiple of 4." + ) + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + if vgg_frontend: + self.encoder_embed = VggSubsampling(num_features, d_model) + else: + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.encoder = EmformerEncoder( + chunk_length // 4, + d_model, + nhead, + dim_feedforward, + num_encoder_layers, + dropout, + cnn_module_kernel, + left_context_length=left_context_length // 4, + right_context_length=right_context_length // 4, + max_memory_size=max_memory_size, + tanh_on_mem=tanh_on_mem, + negative_inf=negative_inf, + ) + + # TODO(fangjun): remove dropout + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), nn.Linear(d_model, output_dim) + ) + + def forward( + self, x: torch.Tensor, x_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass for training and non-streaming inference. + + B: batch size; + D: feature dimension; + T: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (B, T, D). + x_lens (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x, containing the + right_context at the end. + + Returns: + (Tensor, Tensor): + - output logits, with shape (B, T', D), where + T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4. + - logits lengths, with shape (B,), without containing the + right_context at the end. + """ + x = self.encoder_embed(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # Caution: We assume the subsampling factor is 4! + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + x_lens = ((x_lens - 1) // 2 - 1) // 2 + assert x.size(0) == x_lens.max().item() + + output, output_lengths = self.encoder(x, x_lens) # (T, N, C) + + logits = self.encoder_output_layer(output) + logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return logits, output_lengths + + @torch.jit.export + def infer( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + states: Optional[List[List[torch.Tensor]]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]: + """Forward pass for streaming inference. + + B: batch size; + D: feature dimension; + T: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (B, T, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x, containing the + right_context at the end. + states (List[List[torch.Tensor]], optional): + Cached states from proceeding chunk's computation, where each + element (List[torch.Tensor]) corresponding to each emformer layer. + (default: None) + Returns: + (Tensor, Tensor): + - output logits, with shape (B, T', D), where + T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4. + - logits lengths, with shape (B,), without containing the + right_context at the end. + - updated states from current chunk's computation. + """ + x = self.encoder_embed(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # Caution: We assume the subsampling factor is 4! + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + x_lens = ((x_lens - 1) // 2 - 1) // 2 + assert x.size(0) == x_lens.max().item() + + output, output_lengths, output_states = self.encoder.infer( + x, x_lens, states + ) # (T, N, C) + + logits = self.encoder_output_layer(output) + logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return logits, output_lengths, output_states + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py # noqa + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + self.norm = nn.LayerNorm(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = Swish() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + # x is (batch, channels, time) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = x.permute(0, 2, 1) + + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) diff --git a/egs/librispeech/ASR/conv_emformer_transducer/encoder_interface.py b/egs/librispeech/ASR/conv_emformer_transducer/encoder_interface.py new file mode 120000 index 000000000..aa5d0217a --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer/encoder_interface.py @@ -0,0 +1 @@ +../transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conv_emformer_transducer/joiner.py b/egs/librispeech/ASR/conv_emformer_transducer/joiner.py new file mode 120000 index 000000000..81ad47c55 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conv_emformer_transducer/model.py b/egs/librispeech/ASR/conv_emformer_transducer/model.py new file mode 120000 index 000000000..a61a0a23f --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer/model.py @@ -0,0 +1 @@ +../pruned_transducer_stateless/model.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conv_emformer_transducer/noam.py b/egs/librispeech/ASR/conv_emformer_transducer/noam.py new file mode 100644 index 000000000..e46bf35fb --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer/noam.py @@ -0,0 +1,104 @@ +# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +class Noam(object): + """ + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) diff --git a/egs/librispeech/ASR/conv_emformer_transducer/subsampling.py b/egs/librispeech/ASR/conv_emformer_transducer/subsampling.py new file mode 120000 index 000000000..6fee09e58 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer/subsampling.py @@ -0,0 +1 @@ +../conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py b/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py new file mode 100644 index 000000000..1f735637f --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py @@ -0,0 +1,359 @@ +import torch + + +def test_emformer_attention_forward(): + from emformer import EmformerAttention + + B, D = 2, 256 + U, R = 12, 2 + chunk_length = 2 + attention = EmformerAttention(embed_dim=D, nhead=8) + + for use_memory in [True, False]: + if use_memory: + S = U // chunk_length + M = S - 1 + else: + S, M = 0, 0 + + Q, KV = R + U + S, M + R + U + utterance = torch.randn(U, B, D) + lengths = torch.randint(1, U + 1, (B,)) + lengths[0] = U + right_context = torch.randn(R, B, D) + summary = torch.randn(S, B, D) + memory = torch.randn(M, B, D) + attention_mask = torch.rand(Q, KV) >= 0.5 + + output_right_context_utterance, output_memory = attention( + utterance, + lengths, + right_context, + summary, + memory, + attention_mask, + ) + assert output_right_context_utterance.shape == (R + U, B, D) + assert output_memory.shape == (M, B, D) + + +def test_emformer_attention_infer(): + from emformer import EmformerAttention + + B, D = 2, 256 + R, L = 4, 2 + chunk_length = 2 + U = chunk_length + attention = EmformerAttention(embed_dim=D, nhead=8) + + for use_memory in [True, False]: + if use_memory: + S, M = 1, 3 + else: + S, M = 0, 0 + + utterance = torch.randn(U, B, D) + lengths = torch.randint(1, U + 1, (B,)) + lengths[0] = U + right_context = torch.randn(R, B, D) + summary = torch.randn(S, B, D) + memory = torch.randn(M, B, D) + left_context_key = torch.randn(L, B, D) + left_context_val = torch.randn(L, B, D) + + ( + output_right_context_utterance, + output_memory, + next_key, + next_val, + ) = attention.infer( + utterance, + lengths, + right_context, + summary, + memory, + left_context_key, + left_context_val, + ) + assert output_right_context_utterance.shape == (R + U, B, D) + assert output_memory.shape == (S, B, D) + assert next_key.shape == (L + U, B, D) + assert next_val.shape == (L + U, B, D) + + +def test_emformer_layer_forward(): + from emformer import EmformerLayer + + B, D = 2, 256 + U, R, L = 12, 2, 5 + chunk_length = 2 + + for use_memory in [True, False]: + if use_memory: + S = U // chunk_length + M = S - 1 + else: + S, M = 0, 0 + + layer = EmformerLayer( + d_model=D, + nhead=8, + dim_feedforward=1024, + chunk_length=chunk_length, + cnn_module_kernel=3, + left_context_length=L, + max_memory_size=M, + ) + + Q, KV = R + U + S, M + R + U + utterance = torch.randn(U, B, D) + lengths = torch.randint(1, U + 1, (B,)) + lengths[0] = U + right_context = torch.randn(R, B, D) + memory = torch.randn(M, B, D) + attention_mask = torch.rand(Q, KV) >= 0.5 + + output_utterance, output_right_context, output_memory = layer( + utterance, + lengths, + right_context, + memory, + attention_mask, + ) + assert output_utterance.shape == (U, B, D) + assert output_right_context.shape == (R, B, D) + assert output_memory.shape == (M, B, D) + + +def test_emformer_layer_infer(): + from emformer import EmformerLayer + + B, D = 2, 256 + R, L = 2, 5 + chunk_length = 2 + U = chunk_length + + for use_memory in [True, False]: + if use_memory: + M = 3 + else: + M = 0 + + layer = EmformerLayer( + d_model=D, + nhead=8, + dim_feedforward=1024, + chunk_length=chunk_length, + cnn_module_kernel=3, + left_context_length=L, + max_memory_size=M, + ) + + utterance = torch.randn(U, B, D) + lengths = torch.randint(1, U + 1, (B,)) + lengths[0] = U + right_context = torch.randn(R, B, D) + memory = torch.randn(M, B, D) + state = None + ( + output_utterance, + output_right_context, + output_memory, + output_state, + ) = layer.infer( + utterance, + lengths, + right_context, + memory, + state, + ) + assert output_utterance.shape == (U, B, D) + assert output_right_context.shape == (R, B, D) + if use_memory: + assert output_memory.shape == (1, B, D) + else: + assert output_memory.shape == (0, B, D) + assert len(output_state) == 4 + assert output_state[0].shape == (M, B, D) + assert output_state[1].shape == (L, B, D) + assert output_state[2].shape == (L, B, D) + assert output_state[3].shape == (1, B) + + +def test_emformer_encoder_forward(): + from emformer import EmformerEncoder + + B, D = 2, 256 + U, R, L = 12, 2, 5 + chunk_length = 2 + + for use_memory in [True, False]: + if use_memory: + S = U // chunk_length + M = S - 1 + else: + S, M = 0, 0 + + encoder = EmformerEncoder( + chunk_length=chunk_length, + d_model=D, + dim_feedforward=1024, + num_encoder_layers=2, + cnn_module_kernel=3, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + ) + + x = torch.randn(U + R, B, D) + lengths = torch.randint(1, U + R + 1, (B,)) + lengths[0] = U + R + + output, output_lengths = encoder(x, lengths) + assert output.shape == (U, B, D) + assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0)) + + +def test_emformer_encoder_infer(): + from emformer import EmformerEncoder + + B, D = 2, 256 + R, L = 2, 5 + chunk_length = 2 + U = chunk_length + num_chunks = 3 + num_encoder_layers = 2 + + for use_memory in [True, False]: + if use_memory: + M = 3 + else: + M = 0 + + encoder = EmformerEncoder( + chunk_length=chunk_length, + d_model=D, + dim_feedforward=1024, + num_encoder_layers=num_encoder_layers, + cnn_module_kernel=3, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + ) + + states = None + for chunk_idx in range(num_chunks): + x = torch.randn(U + R, B, D) + lengths = torch.randint(1, U + R + 1, (B,)) + lengths[0] = U + R + output, output_lengths, states = encoder.infer(x, lengths, states) + assert output.shape == (U, B, D) + assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0)) + assert len(states) == num_encoder_layers + for state in states: + assert len(state) == 4 + assert state[0].shape == (M, B, D) + assert state[1].shape == (L, B, D) + assert state[2].shape == (L, B, D) + assert torch.equal( + state[3], (chunk_idx + 1) * U * torch.ones_like(state[3]) + ) + + +def test_emformer_forward(): + from emformer import Emformer + + num_features = 80 + output_dim = 1000 + chunk_length = 8 + L, R = 128, 4 + B, D, U = 2, 256, 80 + for use_memory in [True, False]: + if use_memory: + M = 3 + else: + M = 0 + model = Emformer( + num_features=num_features, + output_dim=output_dim, + chunk_length=chunk_length, + subsampling_factor=4, + d_model=D, + cnn_module_kernel=3, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + vgg_frontend=False, + ) + x = torch.randn(B, U + R + 3, num_features) + x_lens = torch.randint(1, U + R + 3 + 1, (B,)) + x_lens[0] = U + R + 3 + logits, output_lengths = model(x, x_lens) + assert logits.shape == (B, U // 4, output_dim) + assert torch.equal( + output_lengths, + torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0), + ) + + +def test_emformer_infer(): + from emformer import Emformer + + num_features = 80 + output_dim = 1000 + chunk_length = 8 + U = chunk_length + L, R = 128, 4 + B, D = 2, 256 + num_chunks = 3 + num_encoder_layers = 2 + for use_memory in [True, False]: + if use_memory: + M = 3 + else: + M = 0 + model = Emformer( + num_features=num_features, + output_dim=output_dim, + chunk_length=chunk_length, + subsampling_factor=4, + d_model=D, + num_encoder_layers=num_encoder_layers, + cnn_module_kernel=3, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + vgg_frontend=False, + ) + states = None + for chunk_idx in range(num_chunks): + x = torch.randn(B, U + R + 3, num_features) + x_lens = torch.randint(1, U + R + 3 + 1, (B,)) + x_lens[0] = U + R + 3 + logits, output_lengths, states = model.infer(x, x_lens, states) + assert logits.shape == (B, U // 4, output_dim) + assert torch.equal( + output_lengths, + torch.clamp(((x_lens - 1) // 2 - 1) // 2 - R // 4, min=0), + ) + assert len(states) == num_encoder_layers + for state in states: + assert len(state) == 4 + assert state[0].shape == (M, B, D) + assert state[1].shape == (L // 4, B, D) + assert state[2].shape == (L // 4, B, D) + assert torch.equal( + state[3], + U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]), + ) + + +if __name__ == "__main__": + test_emformer_attention_forward() + test_emformer_attention_infer() + test_emformer_layer_forward() + test_emformer_layer_infer() + test_emformer_encoder_forward() + test_emformer_encoder_infer() + test_emformer_forward() + test_emformer_infer() diff --git a/egs/librispeech/ASR/conv_emformer_transducer/train.py b/egs/librispeech/ASR/conv_emformer_transducer/train.py new file mode 100755 index 000000000..bdb541ac6 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer/train.py @@ -0,0 +1,1006 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./transducer_emformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 0 \ + --exp-dir transducer_emformer/exp \ + --full-libri 1 \ + --max-duration 300 +""" + + +import argparse +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from emformer import Emformer +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from noam import Noam +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter + +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import save_checkpoint_with_global_batch_idx +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import ( + AttributeDict, + MetricsTracker, + measure_gradient_norms, + measure_weight_norms, + optim_step_and_measure_param_change, + setup_logger, + str2bool, +) + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--attention-dim", + type=int, + default=512, + help="Attention dim for the Emformer", + ) + + parser.add_argument( + "--nhead", + type=int, + default=8, + help="Number of attention heads for the Emformer", + ) + + parser.add_argument( + "--dim-feedforward", + type=int, + default=2048, + help="Feed-forward dimension for the Emformer", + ) + + parser.add_argument( + "--num-encoder-layers", + type=int, + default=12, + help="Number of encoder layers for the Emformer", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=int, + default=3, + help="Kernel size for the convolution module.", + ) + + parser.add_argument( + "--left-context-length", + type=int, + default=120, + help="Number of frames for the left context in the Emformer", + ) + + parser.add_argument( + "--chunk-length", + type=int, + default=16, + help="Number of frames for each segment in the Emformer", + ) + + parser.add_argument( + "--right-context-length", + type=int, + default=4, + help="Number of frames for right context in the Emformer", + ) + + parser.add_argument( + "--memory-size", + type=int, + default=0, + help="Number of entries in the memory for the Emformer", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + transducer_emformer/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer_emformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lr-factor", + type=float, + default=5.0, + help="The lr_factor for Noam optimizer", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" + "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=8000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - attention_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + "log_diagnostics": False, + # parameters for Emformer + "feature_dim": 80, + "subsampling_factor": 4, + "vgg_frontend": False, + # parameters for decoder + "embedding_dim": 512, + # parameters for Noam + "warm_step": 80000, # For the 100h subset, use 20000 + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Emformer( + num_features=params.feature_dim, + output_dim=params.vocab_size, + subsampling_factor=params.subsampling_factor, + d_model=params.attention_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + cnn_module_kernel=params.cnn_module_kernel, + vgg_frontend=params.vgg_frontend, + left_context_length=params.left_context_length, + chunk_length=params.chunk_length, + right_context_length=params.right_context_length, + max_memory_size=params.memory_size, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + input_dim=params.vocab_size, + inner_dim=params.embedding_dim, + output_dim=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 0: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + sampler: Optional[CutSampler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + sampler=sampler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Emformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = model.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + loss = params.simple_loss_scale * simple_loss + pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + def maybe_log_gradients(tag: str): + if ( + params.log_diagnostics + and tb_writer is not None + and params.batch_idx_train % (params.log_interval * 5) == 0 + ): + tb_writer.add_scalars( + tag, + measure_gradient_norms(model, norm="l2"), + global_step=params.batch_idx_train, + ) + + def maybe_log_weights(tag: str): + if ( + params.log_diagnostics + and tb_writer is not None + and params.batch_idx_train % (params.log_interval * 5) == 0 + ): + tb_writer.add_scalars( + tag, + measure_weight_norms(model, norm="l2"), + global_step=params.batch_idx_train, + ) + + def maybe_log_param_relative_changes(): + if ( + params.log_diagnostics + and tb_writer is not None + and params.batch_idx_train % (params.log_interval * 5) == 0 + ): + deltas = optim_step_and_measure_param_change(model, optimizer) + tb_writer.add_scalars( + "train/relative_param_change_per_minibatch", + deltas, + global_step=params.batch_idx_train, + ) + else: + optimizer.step() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + loss.backward() + + maybe_log_weights("train/param_norms") + maybe_log_gradients("train/grad_norms") + maybe_log_param_relative_changes() + + optimizer.zero_grad() + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + params=params, + optimizer=optimizer, + sampler=train_dl.sampler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 800 + params.warm_step = 20000 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + model.device = device + + optimizer = Noam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + ) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + num_in_total = len(train_cuts) + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + num_left = len(train_cuts) + num_removed = num_in_total - num_left + removed_percent = num_removed / num_in_total * 100 + + logging.info(f"Before removing short and long utterances: {num_in_total}") + logging.info(f"After removing short and long utterances: {num_left}") + logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + sampler=train_dl.sampler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() From a24eef8096d77646207f4e9dec0ac401e27db604 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sun, 10 Apr 2022 20:29:22 +0800 Subject: [PATCH 192/234] update conv_emformer_transducer/emformer.py. --- egs/librispeech/ASR/conv_emformer_transducer/emformer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py index 49e59bd00..c55a73d68 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py @@ -92,6 +92,8 @@ class EmformerAttention(nn.Module): self.emb_to_query = nn.Linear(embed_dim, embed_dim, bias=True) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + self._reset_parameters() + def _reset_parameters(self) -> None: nn.init.xavier_uniform_(self.emb_to_key_value.weight) nn.init.constant_(self.emb_to_key_value.bias, 0.0) From 08473a17aa984a77a38d213ab05a94bac643bc38 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Sun, 10 Apr 2022 23:29:28 +0800 Subject: [PATCH 193/234] Modify init (#301) * update icefall/__init__.py to import more common functions. * update icefall/__init__.py * make imports style consistent. * exclude black check for icefall/__init__.py in pyproject.toml. --- .pre-commit-config.yaml | 2 ++ icefall/__init__.py | 10 ++++++++++ pyproject.toml | 1 + 3 files changed, 13 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b59784dbf..446ba0fe7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,8 @@ repos: hooks: - id: black args: [--line-length=80] + additional_dependencies: ['click==8.0.1'] + exclude: icefall\/__init__\.py - repo: https://github.com/PyCQA/flake8 rev: 3.9.2 diff --git a/icefall/__init__.py b/icefall/__init__.py index f466d6a62..ec77e89b5 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -1,3 +1,13 @@ +# isort:skip_file + +from . import ( + checkpoint, + decode, + dist, + env, + utils +) + from .checkpoint import ( average_checkpoints, find_checkpoints, diff --git a/pyproject.toml b/pyproject.toml index ec5623f90..b4f8c3377 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,4 +10,5 @@ exclude = ''' | \.github )/ | make_kn_lm.py + | icefall\/__init__\.py ''' From f721a2fd7aef94bf91bf59529239b7d35700e2b4 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Sun, 10 Apr 2022 23:34:18 +0800 Subject: [PATCH 194/234] Minor fixes for logging (#296) * Minor fixes for logging * Minor fix --- .../ASR/pruned_transducer_stateless/train.py | 36 ++++++++++-------- icefall/utils.py | 37 +++++++++++++------ 2 files changed, 45 insertions(+), 28 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 17f82e601..e743106ec 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -609,21 +609,6 @@ def train_one_epoch( global_step=params.batch_idx_train, ) - def maybe_log_param_relative_changes(): - if ( - params.log_diagnostics - and tb_writer is not None - and params.batch_idx_train % (params.log_interval * 5) == 0 - ): - deltas = optim_step_and_measure_param_change(model, optimizer) - tb_writer.add_scalars( - "train/relative_param_change_per_minibatch", - deltas, - global_step=params.batch_idx_train, - ) - else: - optimizer.step() - cur_batch_idx = params.get("cur_batch_idx", 0) for batch_idx, batch in enumerate(train_dl): @@ -651,7 +636,26 @@ def train_one_epoch( maybe_log_weights("train/param_norms") maybe_log_gradients("train/grad_norms") - maybe_log_param_relative_changes() + + old_parameters = None + if ( + params.log_diagnostics + and tb_writer is not None + and params.batch_idx_train % (params.log_interval * 5) == 0 + ): + old_parameters = { + n: p.detach().clone() for n, p in model.named_parameters() + } + + optimizer.step() + + if old_parameters is not None: + deltas = optim_step_and_measure_param_change(model, old_parameters) + tb_writer.add_scalars( + "train/relative_param_change_per_minibatch", + deltas, + global_step=params.batch_idx_train, + ) optimizer.zero_grad() diff --git a/icefall/utils.py b/icefall/utils.py index c231dbbe4..daccd4346 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -25,15 +25,14 @@ from collections import defaultdict from contextlib import contextmanager from datetime import datetime from pathlib import Path -from typing import Dict, Iterable, List, TextIO, Optional, Tuple, Union +from typing import Dict, Iterable, List, TextIO, Tuple, Union import k2 import k2.version import kaldialign import torch -import torch.nn as nn import torch.distributed as dist -from torch.cuda.amp import GradScaler +import torch.nn as nn from torch.utils.tensorboard import SummaryWriter Pathlike = Union[str, Path] @@ -758,11 +757,10 @@ def measure_gradient_norms( def optim_step_and_measure_param_change( model: nn.Module, - optimizer: torch.optim.Optimizer, - scaler: Optional[GradScaler] = None, + old_parameters: Dict[str, nn.parameter.Parameter], ) -> Dict[str, float]: """ - Perform model weight update and measure the "relative change in parameters per minibatch." + Measure the "relative change in parameters per minibatch." It is understood as a ratio between the L2 norm of the difference between original and updates parameters, and the L2 norm of the original parameter. It is given by the formula: @@ -770,16 +768,31 @@ def optim_step_and_measure_param_change( \begin{aligned} \delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2} \end{aligned} - """ - param_copy = {n: p.detach().clone() for n, p in model.named_parameters()} - if scaler: - scaler.step(optimizer) - else: + + This function is supposed to be used as follows: + + .. code-block:: python + + old_parameters = { + n: p.detach().clone() for n, p in model.named_parameters() + } + optimizer.step() + + deltas = optim_step_and_measure_param_change(old_parameters) + + Args: + model: A torch.nn.Module instance. + old_parameters: + A Dict of named_parameters before optimizer.step(). + + Return: + A Dict containing the relative change for each parameter. + """ relative_change = {} with torch.no_grad(): for n, p_new in model.named_parameters(): - p_orig = param_copy[n] + p_orig = old_parameters[n] delta = l2_norm(p_orig - p_new) / l2_norm(p_orig) relative_change[n] = delta.item() return relative_change From 46d52dda1080baad3f3468a2040eed962c7e73fd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 11 Apr 2022 12:03:41 +0800 Subject: [PATCH 195/234] Fix dir names --- .../ASR/pruned_transducer_stateless2/decode.py | 18 +++++++++--------- .../ASR/pruned_transducer_stateless2/export.py | 12 ++++++------ .../ASR/pruned_transducer_stateless2/train.py | 2 +- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 8e924bf96..38aff8834 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -18,36 +18,36 @@ """ Usage: (1) greedy search -./pruned_transducer_stateless/decode.py \ +./pruned_transducer_stateless2/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless/exp \ + --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 100 \ --decoding-method greedy_search (2) beam search -./pruned_transducer_stateless/decode.py \ +./pruned_transducer_stateless2/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless/exp \ + --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 100 \ --decoding-method beam_search \ --beam-size 4 (3) modified beam search -./pruned_transducer_stateless/decode.py \ +./pruned_transducer_stateless2/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless/exp \ + --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 100 \ --decoding-method modified_beam_search \ --beam-size 4 (4) fast beam search -./pruned_transducer_stateless/decode.py \ +./pruned_transducer_stateless2/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless/exp \ + --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 1500 \ --decoding-method fast_beam_search \ --beam 4 \ @@ -124,7 +124,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless/exp", + default="pruned_transducer_stateless2/exp", help="The experiment dir", ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index 7d2a07817..b5757ee8c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -20,23 +20,23 @@ # to a single one using model averaging. """ Usage: -./pruned_transducer_stateless/export.py \ - --exp-dir ./pruned_transducer_stateless/exp \ +./pruned_transducer_stateless2/export.py \ + --exp-dir ./pruned_transducer_stateless2/exp \ --bpe-model data/lang_bpe_500/bpe.model \ --epoch 20 \ --avg 10 It will generate a file exp_dir/pretrained.pt -To use the generated file with `pruned_transducer_stateless/decode.py`, +To use the generated file with `pruned_transducer_stateless2/decode.py`, you can do: cd /path/to/exp_dir ln -s pretrained.pt epoch-9999.pt cd /path/to/egs/librispeech/ASR - ./pruned_transducer_stateless/decode.py \ - --exp-dir ./pruned_transducer_stateless/exp \ + ./pruned_transducer_stateless2/decode.py \ + --exp-dir ./pruned_transducer_stateless2/exp \ --epoch 9999 \ --avg 1 \ --max-duration 100 \ @@ -80,7 +80,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless/exp", + default="pruned_transducer_stateless2/exp", help="""It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved """, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 62dc825b6..c24fbe9a1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -116,7 +116,7 @@ def get_parser(): default=0, help="""Resume training from from this epoch. If it is positive, it will load checkpoint from - transducer_stateless/exp/epoch-{start_epoch-1}.pt + transducer_stateless2/exp/epoch-{start_epoch-1}.pt """, ) From 1d74c5e59686e3248b9d3a8e911dd4d7e60d30eb Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Mon, 11 Apr 2022 12:28:15 +0800 Subject: [PATCH 196/234] Support causal convolution in emformer encoder layer. --- .../ASR/conv_emformer_transducer/emformer.py | 46 ++++++++++++++++--- .../conv_emformer_transducer/test_emformer.py | 6 +++ .../ASR/conv_emformer_transducer/train.py | 8 ++++ 3 files changed, 53 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py index c55a73d68..032ecb77d 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py @@ -458,6 +458,8 @@ class EmformerLayer(nn.Module): If ``True``, applies tanh to memory elements. (Default: ``False``) negative_inf (float, optional): Value to use for negative infinity in attention weights. (Default: -1e8) + causal (bool): + Whether use causal convolution (default=False). """ def __init__( @@ -472,6 +474,7 @@ class EmformerLayer(nn.Module): max_memory_size: int = 0, tanh_on_mem: bool = False, negative_inf: float = -1e8, + causal: bool = False, ): super().__init__() @@ -500,7 +503,11 @@ class EmformerLayer(nn.Module): nn.Linear(dim_feedforward, d_model), ) - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + self.conv_module = ConvolutionModule( + d_model, + cnn_module_kernel, + causal=causal, + ) self.norm_ff_macaron = nn.LayerNorm(d_model) self.norm_ff = nn.LayerNorm(d_model) @@ -910,6 +917,8 @@ class EmformerEncoder(nn.Module): If ``true``, applies tanh to memory elements. (default: ``false``) negative_inf (float, optional): Value to use for negative infinity in attention weights. (default: -1e8) + causal (bool): + Whether use causal convolution (default=False). """ def __init__( @@ -926,6 +935,7 @@ class EmformerEncoder(nn.Module): max_memory_size: int = 0, tanh_on_mem: bool = False, negative_inf: float = -1e8, + causal: bool = False, ): super().__init__() @@ -949,6 +959,7 @@ class EmformerEncoder(nn.Module): max_memory_size=max_memory_size, tanh_on_mem=tanh_on_mem, negative_inf=negative_inf, + causal=causal, ) for layer_idx in range(num_encoder_layers) ] @@ -1220,6 +1231,7 @@ class Emformer(EncoderInterface): max_memory_size: int = 0, tanh_on_mem: bool = False, negative_inf: float = -1e8, + causal: bool = False, ): super().__init__() @@ -1261,6 +1273,7 @@ class Emformer(EncoderInterface): max_memory_size=max_memory_size, tanh_on_mem=tanh_on_mem, negative_inf=negative_inf, + causal=causal, ) # TODO(fangjun): remove dropout @@ -1366,14 +1379,22 @@ class ConvolutionModule(nn.Module): Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py # noqa Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - + channels (int): + The number of channels of conv layers. + kernel_size (int): + Kernerl size of conv layers. + bias (bool): + Whether to use bias in conv layers (default=True). + causal (bool): + Whether use causal convolution (default=False). """ def __init__( - self, channels: int, kernel_size: int, bias: bool = True + self, + channels: int, + kernel_size: int, + bias: bool = True, + causal: bool = False, ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() @@ -1388,12 +1409,19 @@ class ConvolutionModule(nn.Module): padding=0, bias=bias, ) + + if causal: + self.left_padding = kernel_size - 1 + padding = 0 + else: + self.left_padding = 0 + padding = (kernel_size - 1) // 2 self.depthwise_conv = nn.Conv1d( channels, channels, kernel_size, stride=1, - padding=(kernel_size - 1) // 2, + padding=padding, groups=channels, bias=bias, ) @@ -1426,6 +1454,10 @@ class ConvolutionModule(nn.Module): x = nn.functional.glu(x, dim=1) # (batch, channels, time) # 1D Depthwise Conv + if self.left_padding > 0: + # manualy padding self.lorder zeros to the left + # make depthwise_conv causal + x = nn.functional.pad(x, (self.left_padding, 0), "constant", 0.0) x = self.depthwise_conv(x) # x is (batch, channels, time) x = x.permute(0, 2, 1) diff --git a/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py b/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py index 1f735637f..7685bfb26 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py @@ -103,6 +103,7 @@ def test_emformer_layer_forward(): cnn_module_kernel=3, left_context_length=L, max_memory_size=M, + causal=True, ) Q, KV = R + U + S, M + R + U @@ -147,6 +148,7 @@ def test_emformer_layer_infer(): cnn_module_kernel=3, left_context_length=L, max_memory_size=M, + causal=True, ) utterance = torch.randn(U, B, D) @@ -203,6 +205,7 @@ def test_emformer_encoder_forward(): left_context_length=L, right_context_length=R, max_memory_size=M, + causal=True, ) x = torch.randn(U + R, B, D) @@ -239,6 +242,7 @@ def test_emformer_encoder_infer(): left_context_length=L, right_context_length=R, max_memory_size=M, + causal=True, ) states = None @@ -284,6 +288,7 @@ def test_emformer_forward(): right_context_length=R, max_memory_size=M, vgg_frontend=False, + causal=True, ) x = torch.randn(B, U + R + 3, num_features) x_lens = torch.randint(1, U + R + 3 + 1, (B,)) @@ -324,6 +329,7 @@ def test_emformer_infer(): right_context_length=R, max_memory_size=M, vgg_frontend=False, + causal=True, ) states = None for chunk_idx in range(num_chunks): diff --git a/egs/librispeech/ASR/conv_emformer_transducer/train.py b/egs/librispeech/ASR/conv_emformer_transducer/train.py index bdb541ac6..d0126bb94 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer/train.py @@ -137,6 +137,13 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="Number of entries in the memory for the Emformer", ) + parser.add_argument( + "--causal-conv", + type=bool, + default=True, + help="Whether use causal convolution.", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -377,6 +384,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: chunk_length=params.chunk_length, right_context_length=params.right_context_length, max_memory_size=params.memory_size, + causal=params.causal_conv, ) return encoder From d5f9d49e536d938b4ccc64bccb1a63bda4ea88fd Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 11 Apr 2022 12:35:29 +0800 Subject: [PATCH 197/234] Modify beam search to be efficient with current joienr --- .../beam_search.py | 766 +++++++++++++++++- 1 file changed, 765 insertions(+), 1 deletion(-) mode change 120000 => 100644 egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py deleted file mode 120000 index 227d2247c..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py new file mode 100644 index 000000000..5876d5158 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -0,0 +1,765 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, List, Optional + +import k2 +import torch +from model import Transducer + +from icefall.decode import one_best_decoding +from icefall.utils import get_texts + + +def fast_beam_search( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, +) -> List[List[int]]: + """It limits the maximum number of symbols per frame to 1. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + Returns: + Return the decoded result. + """ + assert encoder_out.ndim == 3 + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + B, T, C = encoder_out.shape + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(k2.RnntDecodingStream(decoding_graph)) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), decoder_out.unsqueeze(1), project_input=False + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + decoding_streams.terminate_and_flush_to_streams() + lattice = decoding_streams.format_output(encoder_out_lens.tolist()) + + best_path = one_best_decoding(lattice) + hyps = get_texts(best_path) + return hyps + + +def greedy_search( + model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int +) -> List[int]: + """Greedy search for a single utterance. + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + max_sym_per_frame: + Maximum number of symbols per frame. If it is set to 0, the WER + would be 100%. + Returns: + Return the decoded result. + """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + device = model.device + + decoder_input = torch.tensor( + [blank_id] * context_size, device=device, dtype=torch.int64 + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + + T = encoder_out.size(1) + t = 0 + hyp = [blank_id] * context_size + + # Maximum symbols per utterance. + max_sym_per_utt = 1000 + + # symbols per frame + sym_per_frame = 0 + + # symbols per utterance decoded so far + sym_per_utt = 0 + + while t < T and sym_per_utt < max_sym_per_utt: + if sym_per_frame >= max_sym_per_frame: + sym_per_frame = 0 + t += 1 + continue + + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # fmt: on + logits = model.joiner(current_encoder_out, + decoder_out.unsqueeze(1), + project_input=False) + # logits is (1, 1, 1, vocab_size) + + y = logits.argmax().item() + if y != blank_id: + hyp.append(y) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + sym_per_utt += 1 + sym_per_frame += 1 + else: + sym_per_frame = 0 + t += 1 + hyp = hyp[context_size:] # remove blanks + + return hyp + + +def greedy_search_batch( + model: Transducer, encoder_out: torch.Tensor +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + Returns: + Return a list-of-list of token IDs containing the decoded results. + len(ans) equals to encoder_out.size(0). + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + device = model.device + + batch_size = encoder_out.size(0) + T = encoder_out.size(1) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + hyps = [[blank_id] * context_size for _ in range(batch_size)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (batch_size, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + encoder_out = model.joiner.encoder_proj(encoder_out) + + # decoder_out: (batch_size, 1, decoder_out_dim) + for t in range(T): + current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1), + project_input=False) + # logits'shape (batch_size, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + ans = [h[context_size:] for h in hyps] + return ans + + +@dataclass +class Hypothesis: + # The predicted tokens so far. + # Newly predicted tokens are appended to `ys`. + ys: List[int] + + # The log prob of ys. + # It contains only one entry. + log_prob: torch.Tensor + + @property + def key(self) -> str: + """Return a string representation of self.ys""" + return "_".join(map(str, self.ys)) + + +class HypothesisList(object): + def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: + """ + Args: + data: + A dict of Hypotheses. Its key is its `value.key`. + """ + if data is None: + self._data = {} + else: + self._data = data + + @property + def data(self) -> Dict[str, Hypothesis]: + return self._data + + def add(self, hyp: Hypothesis) -> None: + """Add a Hypothesis to `self`. + + If `hyp` already exists in `self`, its probability is updated using + `log-sum-exp` with the existed one. + + Args: + hyp: + The hypothesis to be added. + """ + key = hyp.key + if key in self: + old_hyp = self._data[key] # shallow copy + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) + else: + self._data[key] = hyp + + def get_most_probable(self, length_norm: bool = False) -> Hypothesis: + """Get the most probable hypothesis, i.e., the one with + the largest `log_prob`. + + Args: + length_norm: + If True, the `log_prob` of a hypothesis is normalized by the + number of tokens in it. + Returns: + Return the hypothesis that has the largest `log_prob`. + """ + if length_norm: + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) + else: + return max(self._data.values(), key=lambda hyp: hyp.log_prob) + + def remove(self, hyp: Hypothesis) -> None: + """Remove a given hypothesis. + + Caution: + `self` is modified **in-place**. + + Args: + hyp: + The hypothesis to be removed from `self`. + Note: It must be contained in `self`. Otherwise, + an exception is raised. + """ + key = hyp.key + assert key in self, f"{key} does not exist" + del self._data[key] + + def filter(self, threshold: torch.Tensor) -> "HypothesisList": + """Remove all Hypotheses whose log_prob is less than threshold. + + Caution: + `self` is not modified. Instead, a new HypothesisList is returned. + + Returns: + Return a new HypothesisList containing all hypotheses from `self` + with `log_prob` being greater than the given `threshold`. + """ + ans = HypothesisList() + for _, hyp in self._data.items(): + if hyp.log_prob > threshold: + ans.add(hyp) # shallow copy + return ans + + def topk(self, k: int) -> "HypothesisList": + """Return the top-k hypothesis.""" + hyps = list(self._data.items()) + + hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] + + ans = HypothesisList(dict(hyps)) + return ans + + def __contains__(self, key: str): + return key in self._data + + def __iter__(self): + return iter(self._data.values()) + + def __len__(self) -> int: + return len(self._data) + + def __str__(self) -> str: + s = [] + for key in self: + s.append(key) + return ", ".join(s) + + +def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: + """Return a ragged shape with axes [utt][num_hyps]. + + Args: + hyps: + len(hyps) == batch_size. It contains the current hypothesis for + each utterance in the batch. + Returns: + Return a ragged shape with 2 axes [utt][num_hyps]. Note that + the shape is on CPU. + """ + num_hyps = [len(h) for h in hyps] + + # torch.cumsum() is inclusive sum, so we put a 0 at the beginning + # to get exclusive sum later. + num_hyps.insert(0, 0) + + num_hyps = torch.tensor(num_hyps) + row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) + ans = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=row_splits[-1].item() + ) + return ans + + +def modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, +) -> List[List[int]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + beam: + Number of active paths during the beam search. + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + + batch_size = encoder_out.size(0) + T = encoder_out.size(1) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + B = [HypothesisList() for _ in range(batch_size)] + for i in range(batch_size): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + + hyps_shape = _get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token != blank_id: + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + ans = [h.ys[context_size:] for h in best_hyps] + + return ans + + +def _deprecated_modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, +) -> List[int]: + """It limits the maximum number of symbols per frame to 1. + + It decodes only one utterance at a time. We keep it only for reference. + The function :func:`modified_beam_search` should be preferred as it + supports batch decoding. + + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + Returns: + Return the decoded result. + """ + + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + device = model.device + + T = encoder_out.size(1) + + B = HypothesisList() + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) + # fmt: on + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs is of shape (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + dtype=torch.int64, + ) + # decoder_input is of shape (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_output is of shape (num_hyps, 1, 1, joiner_dim) + + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, 1, -1 + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) + # logits is of shape (num_hyps, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + # now logits is of shape (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) + + log_probs.add_(ys_log_probs) + + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk(beam) + + # topk_hyp_indexes are indexes into `A` + topk_hyp_indexes = topk_indexes // logits.size(-1) + topk_token_indexes = topk_indexes % logits.size(-1) + + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() + + for i in range(len(topk_hyp_indexes)): + hyp = A[topk_hyp_indexes[i]] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[i] + if new_token != blank_id: + new_ys.append(new_token) + new_log_prob = topk_log_probs[i] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + return ys + + +def beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, +) -> List[int]: + """ + It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf + + espnet/nets/beam_search_transducer.py#L247 is used as a reference. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + Returns: + Return the decoded result. + """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + device = model.device + + decoder_input = torch.tensor( + [blank_id] * context_size, + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + T = encoder_out.size(1) + t = 0 + + B = HypothesisList() + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) + + max_sym_per_utt = 20000 + + sym_per_utt = 0 + + decoder_cache: Dict[str, torch.Tensor] = {} + + while t < T and sym_per_utt < max_sym_per_utt: + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # fmt: on + A = B + B = HypothesisList() + + joint_cache: Dict[str, torch.Tensor] = {} + + # TODO(fangjun): Implement prefix search to update the `log_prob` + # of hypotheses in A + + while True: + y_star = A.get_most_probable() + A.remove(y_star) + + cached_key = y_star.key + + if cached_key not in decoder_cache: + decoder_input = torch.tensor( + [y_star.ys[-context_size:]], + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + decoder_cache[cached_key] = decoder_out + else: + decoder_out = decoder_cache[cached_key] + + cached_key += f"-t-{t}" + if cached_key not in joint_cache: + logits = model.joiner( + current_encoder_out, + decoder_out.unsqueeze(1), + project_input=False + ) + + # TODO(fangjun): Scale the blank posterior + log_prob = logits.log_softmax(dim=-1) + # log_prob is (1, 1, 1, vocab_size) + log_prob = log_prob.squeeze() + # Now log_prob is (vocab_size,) + joint_cache[cached_key] = log_prob + else: + log_prob = joint_cache[cached_key] + + # First, process the blank symbol + skip_log_prob = log_prob[blank_id] + new_y_star_log_prob = y_star.log_prob + skip_log_prob + + # ys[:] returns a copy of ys + B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) + + # Second, process other non-blank labels + values, indices = log_prob.topk(beam + 1) + for i, v in zip(indices.tolist(), values.tolist()): + if i == blank_id: + continue + new_ys = y_star.ys + [i] + new_log_prob = y_star.log_prob + v + A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) + + # Check whether B contains more than "beam" elements more probable + # than the most probable in A + A_most_probable = A.get_most_probable() + + kept_B = B.filter(A_most_probable.log_prob) + + if len(kept_B) >= beam: + B = kept_B.topk(beam) + break + + t += 1 + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + return ys From 6c1f9b5181e23a69b122596359d574d28ebcf440 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Mon, 11 Apr 2022 12:38:45 +0800 Subject: [PATCH 198/234] Add wenet ref in ConvolutionModule class. --- egs/librispeech/ASR/conv_emformer_transducer/emformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py index 032ecb77d..317ac4a4b 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py @@ -1410,6 +1410,7 @@ class ConvolutionModule(nn.Module): bias=bias, ) + # from https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/convolution.py # noqa if causal: self.left_padding = kernel_size - 1 padding = 0 From 651745b22036c29483ed9a24e62a90724b73a790 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Mon, 11 Apr 2022 12:42:47 +0800 Subject: [PATCH 199/234] minor fix doc in emformer.py --- egs/librispeech/ASR/conv_emformer_transducer/emformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py index 317ac4a4b..5ac65141e 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py @@ -1410,7 +1410,7 @@ class ConvolutionModule(nn.Module): bias=bias, ) - # from https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/convolution.py # noqa + # from https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/convolution.py # noqa if causal: self.left_padding = kernel_size - 1 padding = 0 From 507833208868b8ed07a555437891f208274bbb3d Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 11 Apr 2022 14:58:15 +0800 Subject: [PATCH 200/234] Fix adding learning rate to tensorboard --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index c24fbe9a1..b9ea0def6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -695,7 +695,7 @@ def train_one_epoch( ) if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr) + tb_writer.add_scalar("train/learning_rate", cur_params.batch_idx_train) loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train From 03c7c2613d08bb82323e0b3eb2ddc46f0f8260e5 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 11 Apr 2022 15:13:42 +0800 Subject: [PATCH 201/234] Fix docs in optim.py --- egs/librispeech/ASR/pruned_transducer_stateless2/optim.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 4f7392d3a..b0d269571 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -240,7 +240,7 @@ class LRScheduler(object): class Eden(LRScheduler): """ Eden scheduler. - lr = initial_lr = (((batch**2 + lr_batches**2) / lr_batchses**2) ** -0.25 * + lr = initial_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) E.g. suggest initial-lr = 0.003 (passed to optimizer). @@ -250,7 +250,9 @@ class Eden(LRScheduler): lr_batches: the number of batches after which we start significantly decreasing the learning rate, suggest 5000. lr_epochs: the number of epochs after which we start significantly - decreasing the learning rate, suggest 6. + decreasing the learning rate, suggest 6 if you plan to do e.g. + 20 to 40 epochs, but may need smaller number if dataset is huge + and you will do few epochs. """ def __init__(self, optimizer: Optimizer, lr_batches: Union[int, float], From 7012fd65b5175ed7a1003bce1603a8a5d7baa248 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Mon, 11 Apr 2022 16:49:54 +0800 Subject: [PATCH 202/234] Support mix precision training on the reworked model (#305) * Add mix precision support * Minor fixes * Minor fixes * Minor fixes --- .../ASR/pruned_transducer_stateless2/model.py | 43 ++++++----- .../ASR/pruned_transducer_stateless2/train.py | 75 ++++++++++++++----- 2 files changed, 80 insertions(+), 38 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index a9178c8b3..81f6df790 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -141,17 +141,21 @@ class Transducer(nn.Module): boundary[:, 2] = y_lens boundary[:, 3] = x_lens - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=self.simple_lm_proj(decoder_out), - am=self.simple_am_proj(encoder_out), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction="sum", - return_grad=True, - ) + lm=self.simple_lm_proj(decoder_out) + am=self.simple_am_proj(encoder_out) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) # ranges : [B, T, prune_range] ranges = k2.get_rnnt_prune_ranges( @@ -176,13 +180,14 @@ class Transducer(nn.Module): logits = self.joiner(am_pruned, lm_pruned, project_input=False) - pruned_loss = k2.rnnt_loss_pruned( - logits=logits, - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - reduction="sum", - ) + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) return (simple_loss, pruned_loss) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index b9ea0def6..d08fa15b5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -29,7 +29,16 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --full-libri 1 \ --max-duration 300 +# For mix precision training: +./pruned_transducer_stateless2/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 0 \ + --use_fp16 1 \ + --exp-dir pruned_transducer_stateless2/exp \ + --full-libri 1 \ + --max-duration 550 """ @@ -58,6 +67,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eve, Eden from torch import Tensor +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -249,6 +259,13 @@ def get_parser(): """, ) + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + return parser @@ -447,6 +464,7 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, rank: int = 0, ) -> None: """Save model, optimizer, scheduler and training stats to file. @@ -460,6 +478,8 @@ def save_checkpoint( The optimizer used in the training. sampler: The sampler for the training dataset. + scaler: + The scaler used for mix precision training. """ if rank != 0: return @@ -471,6 +491,7 @@ def save_checkpoint( optimizer=optimizer, scheduler=scheduler, sampler=sampler, + scaler=scaler, rank=rank, ) @@ -599,6 +620,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -622,6 +644,8 @@ def train_one_epoch( Dataloader for the training dataset. valid_dl: Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. tb_writer: Writer to write log messages to tensorboard. world_size: @@ -644,22 +668,24 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step) - ) + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step) + ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. - loss.backward() + scaler.scale(loss).backward() scheduler.step_batch(params.batch_idx_train) - optimizer.step() + scaler.step(optimizer) + scaler.update() optimizer.zero_grad() if params.print_diagnostics and batch_idx == 5: @@ -676,6 +702,7 @@ def train_one_epoch( optimizer=optimizer, scheduler=scheduler, sampler=train_dl.sampler, + scaler=scaler, rank=rank, ) del params.cur_batch_idx @@ -695,7 +722,9 @@ def train_one_epoch( ) if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train @@ -850,6 +879,11 @@ def run(rank, world_size, args): params=params, ) + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + for epoch in range(params.start_epoch, params.num_epochs): scheduler.step_epoch(epoch) fix_random_seed(params.seed + epoch) @@ -869,6 +903,7 @@ def run(rank, world_size, args): sp=sp, train_dl=train_dl, valid_dl=valid_dl, + scaler=scaler, tb_writer=tb_writer, world_size=world_size, rank=rank, @@ -884,6 +919,7 @@ def run(rank, world_size, args): optimizer=optimizer, scheduler=scheduler, sampler=train_dl.sampler, + scaler=scaler, rank=rank, ) @@ -913,14 +949,15 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup = 0.0 - ) + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup = 0.0 + ) loss.backward() optimizer.step() optimizer.zero_grad() From 8cb727e24a349538b2e43fbc63cedd05c6f8f2da Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Mon, 11 Apr 2022 17:08:53 +0800 Subject: [PATCH 203/234] Tedlium3 pruned transducer stateless (#261) * update tedlium3-pruned-transducer-stateless-codes * update README.md * update README.md * add fast beam search for decoding * do a change for RESULTS.md * do a change for RESULTS.md * do a fix * do some changes for pruned RNN-T --- README.md | 29 + .../ASR/pruned_transducer_stateless/decode.py | 3 +- .../pruned_transducer_stateless/decoder.py | 4 + .../ASR/pruned_transducer_stateless/train.py | 4 +- egs/tedlium3/ASR/README.md | 8 +- egs/tedlium3/ASR/RESULTS.md | 93 ++- .../pruned_transducer_stateless/__init__.py | 0 .../asr_datamodule.py | 1 + .../beam_search.py | 744 +++++++++++++++++ .../pruned_transducer_stateless/conformer.py | 1 + .../ASR/pruned_transducer_stateless/decode.py | 528 ++++++++++++ .../pruned_transducer_stateless/decoder.py | 1 + .../encoder_interface.py | 1 + .../ASR/pruned_transducer_stateless/export.py | 184 ++++ .../ASR/pruned_transducer_stateless/joiner.py | 1 + .../ASR/pruned_transducer_stateless/model.py | 1 + .../pruned_transducer_stateless/pretrained.py | 346 ++++++++ .../subsampling.py | 1 + .../test_decoder.py | 61 ++ .../ASR/pruned_transducer_stateless/train.py | 783 ++++++++++++++++++ .../transformer.py | 1 + .../ASR/transducer_stateless/README.md | 2 +- .../ASR/transducer_stateless/decode.py | 6 +- .../ASR/transducer_stateless/export.py | 2 +- .../ASR/transducer_stateless/train.py | 2 +- 25 files changed, 2789 insertions(+), 18 deletions(-) create mode 100644 egs/tedlium3/ASR/pruned_transducer_stateless/__init__.py create mode 120000 egs/tedlium3/ASR/pruned_transducer_stateless/asr_datamodule.py create mode 100644 egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py create mode 120000 egs/tedlium3/ASR/pruned_transducer_stateless/conformer.py create mode 100755 egs/tedlium3/ASR/pruned_transducer_stateless/decode.py create mode 120000 egs/tedlium3/ASR/pruned_transducer_stateless/decoder.py create mode 120000 egs/tedlium3/ASR/pruned_transducer_stateless/encoder_interface.py create mode 100644 egs/tedlium3/ASR/pruned_transducer_stateless/export.py create mode 120000 egs/tedlium3/ASR/pruned_transducer_stateless/joiner.py create mode 120000 egs/tedlium3/ASR/pruned_transducer_stateless/model.py create mode 100644 egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py create mode 120000 egs/tedlium3/ASR/pruned_transducer_stateless/subsampling.py create mode 100755 egs/tedlium3/ASR/pruned_transducer_stateless/test_decoder.py create mode 100755 egs/tedlium3/ASR/pruned_transducer_stateless/train.py create mode 120000 egs/tedlium3/ASR/pruned_transducer_stateless/transformer.py diff --git a/README.md b/README.md index 79d8039ff..5f7f7c7dd 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ We provide four recipes at present: - [LibriSpeech][librispeech] - [Aishell][aishell] - [TIMIT][timit] + - [TED-LIUM3][tedlium3] ### yesno @@ -153,6 +154,31 @@ The PER for this model is: We provide a Colab notebook to run a pre-trained TDNN LiGRU CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/11IT-k4HQIgQngXz1uvWsEYktjqQt7Tmb?usp=sharing) +### TED-LIUM3 + +We provide two models for this recipe: [Transducer Stateless: Conformer encoder + Embedding decoder][TED-LIUM3_transducer_stateless] and [Pruned Transducer Stateless: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][TED-LIUM3_pruned_transducer_stateless]. + +#### Transducer Stateless: Conformer encoder + Embedding decoder + +The best WER using modified beam search with beam size 4 is: + +| | dev | test | +|-----|-------|--------| +| WER | 6.91 | 6.33 | + +Note: No auxiliary losses are used in the training and no LMs are used in the decoding. + +We provide a Colab notebook to run a pre-trained Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1MmY5bBxwvKLNT4A2DJnwiqRXhdchUqPN?usp=sharing) + +#### Pruned Transducer Stateless: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss + +The best WER using modified beam search with beam size 4 is: + +| | dev | test | +|-----|-------|--------| +| WER | 6.72 | 6.12 | + +We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1je_1zGrOkGVVd4WLzgkXRHxl-I27yWtz?usp=sharing) ## Deployment with C++ @@ -175,8 +201,11 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [Aishell_conformer_ctc]: egs/aishell/ASR/conformer_ctc [TIMIT_tdnn_lstm_ctc]: egs/timit/ASR/tdnn_lstm_ctc [TIMIT_tdnn_ligru_ctc]: egs/timit/ASR/tdnn_ligru_ctc +[TED-LIUM3_transducer_stateless]: egs/tedlium3/ASR/transducer_stateless +[TED-LIUM3_pruned_transducer_stateless]: egs/tedlium3/ASR/pruned_transducer_stateless [yesno]: egs/yesno/ASR [librispeech]: egs/librispeech/ASR [aishell]: egs/aishell/ASR [timit]: egs/timit/ASR +[tedlium3]: egs/tedlium3/ASR [k2]: https://github.com/k2-fsa/k2 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 49b1308b0..0e3b0f197 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -483,8 +483,9 @@ def main(): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # is defined in local/train_bpe_model.py + # and is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index 8c728fdc5..f4355e8a0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -37,6 +37,7 @@ class Decoder(nn.Module): vocab_size: int, embedding_dim: int, blank_id: int, + unk_id: int, context_size: int, ): """ @@ -47,6 +48,8 @@ class Decoder(nn.Module): Dimension of the input embedding. blank_id: The ID of the blank symbol. + unk_id: + The ID of the unk symbol. context_size: Number of previous words to use to predict the next word. 1 means bigram; 2 means trigram. n means (n+1)-gram. @@ -58,6 +61,7 @@ class Decoder(nn.Module): padding_idx=blank_id, ) self.blank_id = blank_id + self.unk_id = unk_id assert context_size >= 1, context_size self.context_size = context_size diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index e743106ec..f0ea12d62 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -319,6 +319,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: vocab_size=params.vocab_size, embedding_dim=params.embedding_dim, blank_id=params.blank_id, + unk_id=params.unk_id, context_size=params.context_size, ) return decoder @@ -756,8 +757,9 @@ def run(rank, world_size, args): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # is defined in local/train_bpe_model.py + # and is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() logging.info(params) diff --git a/egs/tedlium3/ASR/README.md b/egs/tedlium3/ASR/README.md index 57bd9458b..0740258a7 100644 --- a/egs/tedlium3/ASR/README.md +++ b/egs/tedlium3/ASR/README.md @@ -8,10 +8,10 @@ This recipe includes some different ASR models trained with TedLium3. There are various folders containing the name `transducer` in this folder. The following table lists the differences among them. -| | Encoder | Decoder | -|------------------------|-----------|--------------------| -| `transducer_stateless` | Conformer | Embedding + Conv1d | - +| | Encoder | Decoder | Comment | +|----------------------------------|-----------|--------------------|-----------------------------| +| `transducer_stateless` | Conformer | Embedding + Conv1d | | +| `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss | | The decoder in `transducer_stateless` is modified from the paper [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). diff --git a/egs/tedlium3/ASR/RESULTS.md b/egs/tedlium3/ASR/RESULTS.md index 407b77fa8..beeeb047b 100644 --- a/egs/tedlium3/ASR/RESULTS.md +++ b/egs/tedlium3/ASR/RESULTS.md @@ -1,9 +1,90 @@ ## Results +### TedLium3 BPE training results (Pruned Transducer) + +#### 2022-03-21 + +Using the codes from this PR https://github.com/k2-fsa/icefall/pull/261. + +The WERs are + +| | dev | test | comment | +|------------------------------------|------------|------------|------------------------------------------| +| greedy search | 7.27 | 6.69 | --epoch 29, --avg 13, --max-duration 100 | +| beam search (beam size 4) | 6.70 | 6.04 | --epoch 29, --avg 13, --max-duration 100 | +| modified beam search (beam size 4) | 6.77 | 6.12 | --epoch 29, --avg 13, --max-duration 100 | +| fast beam search (set as default) | 7.14 | 6.50 | --epoch 29, --avg 13, --max-duration 1500| + +The training command for reproducing is given below: + +``` +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 0 \ + --exp-dir pruned_transducer_stateless/exp \ + --max-duration 300 +``` + +The tensorboard training log can be found at +https://tensorboard.dev/experiment/VpA8b7SZQ7CEjZs9WZ5HNA/#scalars + +The decoding command is: +``` +epoch=29 +avg=13 + +## greedy search +./pruned_transducer_stateless/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir pruned_transducer_stateless/exp \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --max-duration 100 + +## beam search +./pruned_transducer_stateless/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir pruned_transducer_stateless/exp \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --max-duration 100 \ + --decoding-method beam_search \ + --beam-size 4 + +## modified beam search +./pruned_transducer_stateless/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir pruned_transducer_stateless/exp \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --max-duration 100 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +## fast beam search +./pruned_transducer_stateless/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir ./pruned_transducer_stateless/exp \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --max-duration 1500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +``` + +A pre-trained model and decoding logs can be found at + ### TedLium3 BPE training results (Transducer) #### Conformer encoder + embedding decoder +##### 2022-03-21 + Using the codes from this PR https://github.com/k2-fsa/icefall/pull/233 And the SpecAugment codes from this PR https://github.com/lhotse-speech/lhotse/pull/604 @@ -14,9 +95,9 @@ The WERs are | | dev | test | comment | |------------------------------------|------------|------------|------------------------------------------| -| greedy search | 7.19 | 6.57 | --epoch 29, --avg 16, --max-duration 100 | -| beam search (beam size 4) | 7.12 | 6.37 | --epoch 29, --avg 16, --max-duration 100 | -| modified beam search (beam size 4) | 7.00 | 6.19 | --epoch 29, --avg 16, --max-duration 100 | +| greedy search | 7.19 | 6.70 | --epoch 29, --avg 11, --max-duration 100 | +| beam search (beam size 4) | 7.02 | 6.36 | --epoch 29, --avg 11, --max-duration 100 | +| modified beam search (beam size 4) | 6.91 | 6.33 | --epoch 29, --avg 11, --max-duration 100 | The training command for reproducing is given below: @@ -28,16 +109,16 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --num-epochs 30 \ --start-epoch 0 \ --exp-dir transducer_stateless/exp \ - --max-duration 200 + --max-duration 300 ``` The tensorboard training log can be found at -https://tensorboard.dev/experiment/zrfXeJO3Q5GmJpP2KRd2VA/#scalars +https://tensorboard.dev/experiment/4ks15jYHR4uMyvpW7Nz76Q/#scalars The decoding command is: ``` epoch=29 -avg=16 +avg=11 ## greedy search ./transducer_stateless/decode.py \ diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/__init__.py b/egs/tedlium3/ASR/pruned_transducer_stateless/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/pruned_transducer_stateless/asr_datamodule.py new file mode 120000 index 000000000..49b2ee483 --- /dev/null +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/asr_datamodule.py @@ -0,0 +1 @@ +../transducer_stateless/asr_datamodule.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py b/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py new file mode 100644 index 000000000..0ae001d3f --- /dev/null +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py @@ -0,0 +1,744 @@ +# Copyright 2020 Xiaomi Corp. (authors: Fangjun Kuang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, List, Optional + +import k2 +import torch +from model import Transducer + +from icefall.decode import one_best_decoding +from icefall.utils import get_texts + + +def fast_beam_search( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, +) -> List[List[int]]: + """It limits the maximum number of symbols per frame to 1. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + Returns: + Return the decoded result. + """ + assert encoder_out.ndim == 3 + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + unk_id = model.decoder.unk_id + + B, T, C = encoder_out.shape + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(k2.RnntDecodingStream(decoding_graph)) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + # current_encoder_out is of shape + # (shape.NumElements(), 1, encoder_out_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), decoder_out.unsqueeze(1) + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + decoding_streams.terminate_and_flush_to_streams() + lattice = decoding_streams.format_output(encoder_out_lens.tolist()) + best_path = one_best_decoding(lattice) + hyps = get_texts(best_path) + new_hyps = [] + for hyp in hyps: + hyp = [idx for idx in hyp if idx != unk_id] + new_hyps.append(hyp) + return new_hyps + + +def greedy_search( + model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int +) -> List[int]: + """Greedy search for a single utterance. + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + max_sym_per_frame: + Maximum number of symbols per frame. If it is set to 0, the WER + would be 100%. + Returns: + Return the decoded result. + """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + + blank_id = model.decoder.blank_id + unk_id = model.decoder.unk_id + context_size = model.decoder.context_size + + device = model.device + + decoder_input = torch.tensor( + [blank_id] * context_size, device=device, dtype=torch.int64 + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + + T = encoder_out.size(1) + t = 0 + hyp = [blank_id] * context_size + + # Maximum symbols per utterance. + max_sym_per_utt = 1000 + + # symbols per frame + sym_per_frame = 0 + + # symbols per utterance decoded so far + sym_per_utt = 0 + + while t < T and sym_per_utt < max_sym_per_utt: + if sym_per_frame >= max_sym_per_frame: + sym_per_frame = 0 + t += 1 + continue + + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # fmt: on + logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) + # logits is (1, 1, 1, vocab_size) + + y = logits.argmax().item() + if y != blank_id and y != unk_id: + hyp.append(y) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + + sym_per_utt += 1 + sym_per_frame += 1 + else: + sym_per_frame = 0 + t += 1 + hyp = hyp[context_size:] # remove blanks + + return hyp + + +def greedy_search_batch( + model: Transducer, encoder_out: torch.Tensor +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + Returns: + Return a list-of-list integers containing the decoded results. + len(ans) equals to encoder_out.size(0). + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + device = model.device + + batch_size = encoder_out.size(0) + T = encoder_out.size(1) + + blank_id = model.decoder.blank_id + unk_id = model.decoder.unk_id + context_size = model.decoder.context_size + + hyps = [[blank_id] * context_size for _ in range(batch_size)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (batch_size, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + # decoder_out: (batch_size, 1, decoder_out_dim) + for t in range(T): + current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) + # logits'shape (batch_size, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id and v != unk_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps] + decoder_input = torch.tensor(decoder_input, device=device) + decoder_out = model.decoder(decoder_input, need_pad=False) + + ans = [h[context_size:] for h in hyps] + return ans + + +@dataclass +class Hypothesis: + # The predicted tokens so far. + # Newly predicted tokens are appended to `ys`. + ys: List[int] + + # The log prob of ys. + # It contains only one entry. + log_prob: torch.Tensor + + @property + def key(self) -> str: + """Return a string representation of self.ys""" + return "_".join(map(str, self.ys)) + + +class HypothesisList(object): + def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: + """ + Args: + data: + A dict of Hypotheses. Its key is its `value.key`. + """ + if data is None: + self._data = {} + else: + self._data = data + + @property + def data(self) -> Dict[str, Hypothesis]: + return self._data + + def add(self, hyp: Hypothesis) -> None: + """Add a Hypothesis to `self`. + + If `hyp` already exists in `self`, its probability is updated using + `log-sum-exp` with the existed one. + + Args: + hyp: + The hypothesis to be added. + """ + key = hyp.key + if key in self: + old_hyp = self._data[key] # shallow copy + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) + else: + self._data[key] = hyp + + def get_most_probable(self, length_norm: bool = False) -> Hypothesis: + """Get the most probable hypothesis, i.e., the one with + the largest `log_prob`. + + Args: + length_norm: + If True, the `log_prob` of a hypothesis is normalized by the + number of tokens in it. + Returns: + Return the hypothesis that has the largest `log_prob`. + """ + if length_norm: + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) + else: + return max(self._data.values(), key=lambda hyp: hyp.log_prob) + + def remove(self, hyp: Hypothesis) -> None: + """Remove a given hypothesis. + + Caution: + `self` is modified **in-place**. + + Args: + hyp: + The hypothesis to be removed from `self`. + Note: It must be contained in `self`. Otherwise, + an exception is raised. + """ + key = hyp.key + assert key in self, f"{key} does not exist" + del self._data[key] + + def filter(self, threshold: torch.Tensor) -> "HypothesisList": + """Remove all Hypotheses whose log_prob is less than threshold. + + Caution: + `self` is not modified. Instead, a new HypothesisList is returned. + + Returns: + Return a new HypothesisList containing all hypotheses from `self` + with `log_prob` being greater than the given `threshold`. + """ + ans = HypothesisList() + for _, hyp in self._data.items(): + if hyp.log_prob > threshold: + ans.add(hyp) # shallow copy + return ans + + def topk(self, k: int) -> "HypothesisList": + """Return the top-k hypothesis.""" + hyps = list(self._data.items()) + + hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] + + ans = HypothesisList(dict(hyps)) + return ans + + def __contains__(self, key: str): + return key in self._data + + def __iter__(self): + return iter(self._data.values()) + + def __len__(self) -> int: + return len(self._data) + + def __str__(self) -> str: + s = [] + for key in self: + s.append(key) + return ", ".join(s) + + +def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: + """Return a ragged shape with axes [utt][num_hyps]. + + Args: + hyps: + len(hyps) == batch_size. It contains the current hypothesis for + each utterance in the batch. + Returns: + Return a ragged shape with 2 axes [utt][num_hyps]. Note that + the shape is on CPU. + """ + num_hyps = [len(h) for h in hyps] + + # torch.cumsum() is inclusive sum, so we put a 0 at the beginning + # to get exclusive sum later. + num_hyps.insert(0, 0) + + num_hyps = torch.tensor(num_hyps) + row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) + ans = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=row_splits[-1].item() + ) + return ans + + +def modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, +) -> List[List[int]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + beam: + Number of active paths during the beam search. + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + + batch_size = encoder_out.size(0) + T = encoder_out.size(1) + + blank_id = model.decoder.blank_id + unk_id = model.decoder.unk_id + context_size = model.decoder.context_size + device = model.device + B = [HypothesisList() for _ in range(batch_size)] + for i in range(batch_size): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + for t in range(T): + current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + + hyps_shape = _get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + # decoder_output is of shape (num_hyps, 1, 1, decoder_output_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + topk_hyp_indexes = torch.div(topk_indexes, vocab_size, rounding_mode="trunc") + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token != blank_id and new_token != unk_id: + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + ans = [h.ys[context_size:] for h in best_hyps] + + return ans + + +def _deprecated_modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, +) -> List[int]: + """It limits the maximum number of symbols per frame to 1. + + It decodes only one utterance at a time. We keep it only for reference. + The function :func:`modified_beam_search` should be preferred as it + supports batch decoding. + + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + Returns: + Return the decoded result. + """ + + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + unk_id = model.decoder.unk_id + context_size = model.decoder.context_size + + device = model.device + + T = encoder_out.size(1) + + B = HypothesisList() + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + for t in range(T): + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) + # fmt: on + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs is of shape (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + dtype=torch.int64, + ) + # decoder_input is of shape (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + # decoder_output is of shape (num_hyps, 1, 1, decoder_output_dim) + + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, 1, -1 + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits is of shape (num_hyps, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + # now logits is of shape (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) + + log_probs.add_(ys_log_probs) + + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk(beam) + + # topk_hyp_indexes are indexes into `A` + topk_hyp_indexes = topk_indexes // logits.size(-1) + topk_token_indexes = topk_indexes % logits.size(-1) + + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() + + for i in range(len(topk_hyp_indexes)): + hyp = A[topk_hyp_indexes[i]] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[i] + if new_token != blank_id and new_token != unk_id: + new_ys.append(new_token) + new_log_prob = topk_log_probs[i] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + return ys + + +def beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, +) -> List[int]: + """ + It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf + + espnet/nets/beam_search_transducer.py#L247 is used as a reference. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + Returns: + Return the decoded result. + """ + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + unk_id = model.decoder.unk_id + context_size = model.decoder.context_size + + device = model.device + + decoder_input = torch.tensor( + [blank_id] * context_size, + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + + T = encoder_out.size(1) + t = 0 + + B = HypothesisList() + B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) + + max_sym_per_utt = 20000 + + sym_per_utt = 0 + + decoder_cache: Dict[str, torch.Tensor] = {} + + while t < T and sym_per_utt < max_sym_per_utt: + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # fmt: on + A = B + B = HypothesisList() + + joint_cache: Dict[str, torch.Tensor] = {} + + # TODO(fangjun): Implement prefix search to update the `log_prob` + # of hypotheses in A + + while True: + y_star = A.get_most_probable() + A.remove(y_star) + + cached_key = y_star.key + + if cached_key not in decoder_cache: + decoder_input = torch.tensor( + [y_star.ys[-context_size:]], + device=device, + dtype=torch.int64, + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_cache[cached_key] = decoder_out + else: + decoder_out = decoder_cache[cached_key] + + cached_key += f"-t-{t}" + if cached_key not in joint_cache: + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1) + ) + + # TODO(fangjun): Scale the blank posterior + + log_prob = logits.log_softmax(dim=-1) + # log_prob is (1, 1, 1, vocab_size) + log_prob = log_prob.squeeze() + # Now log_prob is (vocab_size,) + joint_cache[cached_key] = log_prob + else: + log_prob = joint_cache[cached_key] + + # First, process the blank symbol + skip_log_prob = log_prob[blank_id] + new_y_star_log_prob = y_star.log_prob + skip_log_prob + + # ys[:] returns a copy of ys + B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) + + # Second, process other non-blank labels + values, indices = log_prob.topk(beam + 1) + for i, v in zip(indices.tolist(), values.tolist()): + if i == blank_id or i == unk_id: + continue + new_ys = y_star.ys + [i] + new_log_prob = y_star.log_prob + v + A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) + + # Check whether B contains more than "beam" elements more probable + # than the most probable in A + A_most_probable = A.get_most_probable() + + kept_B = B.filter(A_most_probable.log_prob) + + if len(kept_B) >= beam: + B = kept_B.topk(beam) + break + + t += 1 + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + return ys diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/conformer.py b/egs/tedlium3/ASR/pruned_transducer_stateless/conformer.py new file mode 120000 index 000000000..8be0dc864 --- /dev/null +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/conformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/conformer.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py new file mode 100755 index 000000000..fd8d2dd0e --- /dev/null +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py @@ -0,0 +1,528 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless/decode.py \ + --epoch 29 \ + --avg 13 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 100 \ + --decoding-method greedy_search + +(2) beam search +./pruned_transducer_stateless/decode.py \ + --epoch 29 \ + --avg 13 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 100 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless/decode.py \ + --epoch 29 \ + --avg 13 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 100 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search +./pruned_transducer_stateless/decode.py \ + --epoch 29 \ + --avg 13 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 1500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +""" + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import TedLiumAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=29, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=13, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = model.device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif params.decoding_method == "fast_beam_search": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps + } + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 100 + else: + log_interval = 2 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + batch=batch, + decoding_graph=decoding_graph, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + TedLiumAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.to(device) + model.eval() + model.device = device + + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + tedlium = TedLiumAsrDataModule(args) + dev_cuts = tedlium.dev_cuts() + test_cuts = tedlium.test_cuts() + + dev_dl = tedlium.valid_dataloaders(dev_cuts) + test_dl = tedlium.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dl = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decoder.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decoder.py new file mode 120000 index 000000000..206384eaa --- /dev/null +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless/decoder.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/encoder_interface.py b/egs/tedlium3/ASR/pruned_transducer_stateless/encoder_interface.py new file mode 120000 index 000000000..653c5b09a --- /dev/null +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py new file mode 100644 index 000000000..1e6edbb99 --- /dev/null +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" +Usage: +./pruned_transducer_stateless/export.py \ + --exp-dir ./pruned_transducer_stateless/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 29 \ + --avg 13 + +It will generate a file exp_dir/pretrained.pt + +To use the generated file with `pruned_transducer_stateless/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/tedlium3/ASR + ./pruned_transducer_stateless/decode.py \ + --exp-dir ./pruned_transducer_stateless/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 1 \ + --bpe-model data/lang_bpe_500/bpe.model +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from train import get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=13, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + return parser + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + assert args.jit is False, "Support torchscript will be added later" + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.eval() + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/joiner.py b/egs/tedlium3/ASR/pruned_transducer_stateless/joiner.py new file mode 120000 index 000000000..b3d677eb5 --- /dev/null +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless/joiner.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/model.py b/egs/tedlium3/ASR/pruned_transducer_stateless/model.py new file mode 120000 index 000000000..6b78aed54 --- /dev/null +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless/model.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py new file mode 100644 index 000000000..2c795ede0 --- /dev/null +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py @@ -0,0 +1,346 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# 2022 Xiaomi Crop. (authors: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +(1) greedy search +./pruned_transducer_stateless/pretrained.py \ + --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + --max-sym-per-frame 1 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless/pretrained.py \ + --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(3) modified beam search +./pruned_transducer_stateless/pretrained.py \ + --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless/exp/epoch-xx.pt`. + +Note: ./pruned_transducer_stateless/exp/pretrained.pt is generated by +./pruned_transducer_stateless/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torch.nn as nn +import torchaudio +from beam_search import beam_search, greedy_search, modified_beam_search +from conformer import Conformer +from decoder import Decoder +from joiner import Joiner +from model import Transducer +from torch.nn.utils.rnn import pad_sequence + +from icefall.env import get_env_info +from icefall.utils import AttributeDict + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="Used only when --method is beam_search and modified_beam_search ", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=3, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "sample_rate": 16000, + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "attention_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + "vgg_frontend": False, + # parameters for decoder + "embedding_dim": 512, + "env_info": get_env_info(), + } + ) + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Conformer( + num_features=params.feature_dim, + output_dim=params.vocab_size, + subsampling_factor=params.subsampling_factor, + d_model=params.attention_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + vgg_frontend=params.vgg_frontend, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + blank_id=params.blank_id, + unk_id=params.unk_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + input_dim=params.vocab_size, + inner_dim=params.embedding_dim, + output_dim=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + ) + return model + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + with torch.no_grad(): + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) + elif params.method == "modified_beam_search": + hyp = modified_beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/subsampling.py b/egs/tedlium3/ASR/pruned_transducer_stateless/subsampling.py new file mode 120000 index 000000000..fd7ca8b30 --- /dev/null +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/subsampling.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/test_decoder.py b/egs/tedlium3/ASR/pruned_transducer_stateless/test_decoder.py new file mode 100755 index 000000000..b97bf6150 --- /dev/null +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/test_decoder.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +To run this file, do: + + cd icefall/egs/tedlium3/ASR + python ./pruned_transducer_stateless/test_decoder.py +""" + +import torch +from decoder import Decoder + + +def test_decoder(): + vocab_size = 3 + blank_id = 0 + unk_id = 2 + embedding_dim = 128 + context_size = 4 + + decoder = Decoder( + vocab_size=vocab_size, + embedding_dim=embedding_dim, + blank_id=blank_id, + unk_id=unk_id, + context_size=context_size, + ) + N = 100 + U = 20 + x = torch.randint(low=0, high=vocab_size, size=(N, U)) + y = decoder(x) + assert y.shape == (N, U, vocab_size) + + # for inference + x = torch.randint(low=0, high=vocab_size, size=(N, context_size)) + y = decoder(x, need_pad=False) + assert y.shape == (N, 1, vocab_size) + + +def main(): + test_decoder() + + +if __name__ == "__main__": + main() diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py new file mode 100755 index 000000000..b6fc9a926 --- /dev/null +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py @@ -0,0 +1,783 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 0 \ + --exp-dir pruned_transducer_stateless/exp \ + --max-duration 300 +""" + + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import TedLiumAsrDataModule +from conformer import Conformer +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids +from model import Transducer +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from transformer import Noam + +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12350, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + transducer_stateless/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lr-factor", + type=float, + default=5.0, + help="The lr_factor for Noam optimizer", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" + "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - attention_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for conformer + "feature_dim": 80, + "encoder_out_dim": 512, + "subsampling_factor": 4, + "attention_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + "vgg_frontend": False, + # parameters for decoder + "embedding_dim": 512, + # parameters for Noam + "warm_step": 80000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Conformer and Transformer + encoder = Conformer( + num_features=params.feature_dim, + output_dim=params.vocab_size, + subsampling_factor=params.subsampling_factor, + d_model=params.attention_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + vgg_frontend=params.vgg_frontend, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + blank_id=params.blank_id, + unk_id=params.unk_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + input_dim=params.vocab_size, + inner_dim=params.embedding_dim, + output_dim=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = model.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + unk_id = params.unk_id + y = convert_texts_into_ids(texts, unk_id, sp=sp) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + loss = params.simple_loss_scale * simple_loss + pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + + if batch_idx % params.log_interval == 0: + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + model.device = device + + optimizer = Noam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + ) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + tedlium = TedLiumAsrDataModule(args) + + train_cuts = tedlium.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 17 seconds + return 1.0 <= c.duration <= 17.0 + + num_in_total = len(train_cuts) + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + num_left = len(train_cuts) + num_removed = num_in_total - num_left + removed_percent = num_removed / num_in_total * 100 + + logging.info(f"Before removing short and long utterances: {num_in_total}") + logging.info(f"After removing short and long utterances: {num_left}") + logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") + + train_dl = tedlium.train_dataloaders(train_cuts) + valid_cuts = tedlium.dev_cuts() + valid_dl = tedlium.valid_dataloaders(valid_cuts) + + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def main(): + parser = get_parser() + TedLiumAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/transformer.py b/egs/tedlium3/ASR/pruned_transducer_stateless/transformer.py new file mode 120000 index 000000000..214afed39 --- /dev/null +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/transformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/transformer.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/transducer_stateless/README.md b/egs/tedlium3/ASR/transducer_stateless/README.md index 93af553ec..9b6ed62f1 100644 --- a/egs/tedlium3/ASR/transducer_stateless/README.md +++ b/egs/tedlium3/ASR/transducer_stateless/README.md @@ -16,5 +16,5 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --num-epochs 30 \ --start-epoch 0 \ --exp-dir transducer_stateless/exp \ - --max-duration 200 + --max-duration 300 ``` diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py index c566132b0..3185e7581 100755 --- a/egs/tedlium3/ASR/transducer_stateless/decode.py +++ b/egs/tedlium3/ASR/transducer_stateless/decode.py @@ -21,7 +21,7 @@ Usage: (1) greedy search ./transducer_stateless/decode.py \ --epoch 29 \ - --avg 16 \ + --avg 11 \ --exp-dir ./transducer_stateless/exp \ --max-duration 100 \ --decoding-method greedy_search @@ -29,7 +29,7 @@ Usage: (2) beam search ./transducer_stateless/decode.py \ --epoch 29 \ - --avg 16 \ + --avg 11 \ --exp-dir ./transducer_stateless/exp \ --max-duration 100 \ --decoding-method beam_search \ @@ -38,7 +38,7 @@ Usage: (3) modified beam search ./transducer_stateless/decode.py \ --epoch 29 \ - --avg 16 \ + --avg 11 \ --exp-dir ./transducer_stateless/exp \ --max-duration 100 \ --decoding-method modified_beam_search \ diff --git a/egs/tedlium3/ASR/transducer_stateless/export.py b/egs/tedlium3/ASR/transducer_stateless/export.py index 6a40a1b4f..f2bfa2ec9 100644 --- a/egs/tedlium3/ASR/transducer_stateless/export.py +++ b/egs/tedlium3/ASR/transducer_stateless/export.py @@ -25,7 +25,7 @@ Usage: --exp-dir ./transducer_stateless/exp \ --bpe-model data/lang_bpe_500/bpe.model \ --epoch 29 \ - --avg 16 + --avg 11 It will generate a file exp_dir/pretrained.pt diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py index 11a864423..dda6108c5 100755 --- a/egs/tedlium3/ASR/transducer_stateless/train.py +++ b/egs/tedlium3/ASR/transducer_stateless/train.py @@ -26,7 +26,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --num-epochs 30 \ --start-epoch 0 \ --exp-dir transducer_stateless/exp \ - --max-duration 200 + --max-duration 300 """ From cc0d4ffa4f115e6345274c203716e935b369cb77 Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 11 Apr 2022 15:27:24 +0800 Subject: [PATCH 204/234] Add mix precision support --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index d08fa15b5..c78a0f1c3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -870,7 +870,7 @@ def run(rank, world_size, args): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: + if not params.print_diagnostics and not params.use_fp16: scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl, From ddd8f9e15ef33aa86f2b3f52278d75cdbe0138de Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 11 Apr 2022 15:40:14 +0800 Subject: [PATCH 205/234] Minor fixes --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index c78a0f1c3..31b85d53c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -39,7 +39,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --exp-dir pruned_transducer_stateless2/exp \ --full-libri 1 \ --max-duration 550 - """ From a92133ef960a43d3e9f4834594acad2051c9aa22 Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 11 Apr 2022 15:41:45 +0800 Subject: [PATCH 206/234] Minor fixes --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 31b85d53c..577231995 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -869,7 +869,7 @@ def run(rank, world_size, args): valid_cuts += librispeech.dev_other_cuts() valid_dl = librispeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics and not params.use_fp16: + if not params.print_diagnostics: scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl, From e8eb0b94d912c08afd9adce3675091f05baf3cf0 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 11 Apr 2022 20:56:11 +0800 Subject: [PATCH 207/234] Updating RESULTS.md; fix in beam_search.py --- egs/librispeech/ASR/README.md | 16 ++-- egs/librispeech/ASR/RESULTS.md | 78 +++++++++++++++++++ .../beam_search.py | 2 +- 3 files changed, 88 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index a7b2e2c3b..b3e90a052 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -9,13 +9,15 @@ for how to run models in this recipe. There are various folders containing the name `transducer` in this folder. The following table lists the differences among them. -| | Encoder | Decoder | Comment | -|---------------------------------------|-----------|--------------------|---------------------------------------------------| -| `transducer` | Conformer | LSTM | | -| `transducer_stateless` | Conformer | Embedding + Conv1d | | -| `transducer_lstm` | LSTM | LSTM | | -| `transducer_stateless_multi_datasets` | Conformer | Embedding + Conv1d | Using data from GigaSpeech as extra training data | -| `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss | +| | Encoder | Decoder | Comment | +|---------------------------------------|---------------------|--------------------|---------------------------------------------------| +| `transducer` | Conformer | LSTM | | +| `transducer_stateless` | Conformer | Embedding + Conv1d | | +| `transducer_lstm` | LSTM | LSTM | | +| `transducer_stateless_multi_datasets` | Conformer | Embedding + Conv1d | Using data from GigaSpeech as extra training data | +| `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss | +| `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | + The decoder in `transducer_stateless` is modified from the paper [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 6dbc659f7..ce90da356 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,5 +1,79 @@ ## Results +### LibriSpeech BPE training results (Pruned Transducer 2) + +This is with a reworked version of the conformer encoder, with many changes. + +[pruned_transducer_stateless2](./pruned_transducer_stateless2) + +using commit `34aad74a2c849542dd5f6359c9e6b527e8782fd6`. +See + +The WERs are: + +| | test-clean | test-other | comment | +|-------------------------------------|------------|------------|-------------------------------------------------------------------------------| +| greedy search (max sym per frame 1) | 2.62 | 6.37 | --epoch 25, --avg 8, --max-duration 600 | +| fast beam search | 2.61 | 6.17 | --epoch 25, --avg 8, --max-duration 600 --decoding-method fast_beam_search | +| modified beam search | 2.59 | 6.19 | --epoch 25, --avg 8, --max-duration 600 --decoding-method modified_beam_search| + + +The train and decode commands are: +`python3 ./pruned_transducer_stateless2/train.py --exp-dir=pruned_transducer_stateless2/exp --world-size 8 --num-epochs 26 --full-libri 1 --max-duration 300` +and: +`python3 ./pruned_transducer_stateless2/decode.py --exp-dir pruned_transducer_stateless2/exp --epoch 25 --avg 8 --bpe-model ./data/lang_bpe_500/bpe.model --max-duration 600` + +The Tensorboard log is at + + +The WERs for librispeech 100 hours are: + +Trained with one job: +`python3 ./pruned_transducer_stateless2/train.py --exp-dir=pruned_transducer_stateless2/exp_100h_ws1 --world-size 1 --num-epochs 40 --full-libri 0 --max-duration 300` +and decoded with: +`python3 ./pruned_transducer_stateless2/decode.py --exp-dir pruned_transducer_stateless2/exp_100h_ws1 --epoch 19 --avg 8 --bpe-model ./data/lang_bpe_500/bpe.model --max-duration 600`. + +The Tensorboard log is at (learning rate +schedule is not visible due to a since-fixed bug). + +| | test-clean | test-other | comment | +|-------------------------------------|------------|------------|-------------------------------------------------------| +| greedy search (max sym per frame 1) | 7.12 | 18.42 | --epoch 19 --avg 8 | +| greedy search (max sym per frame 1) | 6.71 | 17.77 | --epoch 29 --avg 8 | +| fast beam search | 6.58 | 17.27 | --epoch 19 --avg 8 --decoding-method fast_beam_search | + +Trained with two jobs: +`python3 ./pruned_transducer_stateless2/train.py --exp-dir=pruned_transducer_stateless2/exp_100h_ws2 --world-size 2 --num-epochs 40 --full-libri 0 --max-duration 300` +and decoded with: +`python3 ./pruned_transducer_stateless2/decode.py --exp-dir pruned_transducer_stateless2/exp_100h_ws2 --epoch 19 --avg 8 --bpe-model ./data/lang_bpe_500/bpe.model --max-duration 600`. + +The Tensorboard log is at +(learning rate schedule is not visible due to a since-fixed bug). + +| | test-clean | test-other | comment | +|-------------------------------------|------------|------------|-----------------------| +| greedy search (max sym per frame 1) | 7.05 | 18.77 | --epoch 19, --avg 8 | +| greedy search (max sym per frame 1) | 6.82 | 18.14 | --epoch 29, --avg 8 | +| greedy search (max sym per frame 1) | 6.81 | 17.66 | --epoch 30, --avg 10 | + + +Trained with 4 jobs: +`python3 ./pruned_transducer_stateless2/train.py --exp-dir=pruned_transducer_stateless2/exp_100h_ws4 --world-size 4 --num-epochs 40 --full-libri 0 --max-duration 300` +and decoded with: +`python3 ./pruned_transducer_stateless2/decode.py --exp-dir pruned_transducer_stateless2/exp_100h_ws4 --epoch 19 --avg 8 --bpe-model ./data/lang_bpe_500/bpe.model --max-duration 600`. + + +The Tensorboard log is at +(learning rate schedule is not visible due to a since-fixed bug). + +| | test-clean | test-other | comment | +|-------------------------------------|------------|------------|-----------------------| +| greedy search (max sym per frame 1) | 7.31 | 19.55 | --epoch 19, --avg 8 | +| greedy search (max sym per frame 1) | 7.08 | 18.59 | --epoch 29, --avg 8 | +| greedy search (max sym per frame 1) | 6.86 | 18.29 | --epoch 30, --avg 10 | + + + ### LibriSpeech BPE training results (Pruned Transducer) Conformer encoder + non-current decoder. The decoder @@ -23,6 +97,10 @@ The WERs are: | modified beam search (beam size 4) | 2.56 | 6.27 | --epoch 42, --avg 11, --max-duration 100 | | beam search (beam size 4) | 2.57 | 6.27 | --epoch 42, --avg 11, --max-duration 100 | + + + + The decoding time for `test-clean` and `test-other` is given below: (A V100 GPU with 32 GB RAM is used for decoding. Note: Not all GPU RAM is used during decoding.) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 5876d5158..d0e5c083f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -89,7 +89,7 @@ def fast_beam_search( # (shape.NumElements(), 1, joiner_dim) # fmt: off current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1) + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) ) # fmt: on logits = model.joiner( From ead822477c3c51190b78a61c48cb29ff8c198cba Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 11 Apr 2022 21:01:13 +0800 Subject: [PATCH 208/234] Fix rebase --- egs/librispeech/ASR/pruned_transducer_stateless2/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 577231995..d08fa15b5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -39,6 +39,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --exp-dir pruned_transducer_stateless2/exp \ --full-libri 1 \ --max-duration 550 + """ From 93c60a9d30a4f6d5e2a200a76efbfe6ef3d1bd53 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Mon, 11 Apr 2022 22:15:18 +0800 Subject: [PATCH 209/234] Code style check for librispeech pruned transducer stateless2 (#308) --- .flake8 | 2 + .../beam_search.py | 19 +- .../pruned_transducer_stateless2/conformer.py | 123 +++--- .../pruned_transducer_stateless2/decoder.py | 4 +- .../pruned_transducer_stateless2/joiner.py | 24 +- .../ASR/pruned_transducer_stateless2/model.py | 24 +- .../ASR/pruned_transducer_stateless2/optim.py | 145 ++++--- .../pruned_transducer_stateless2/scaling.py | 364 ++++++++++++------ .../ASR/pruned_transducer_stateless2/train.py | 69 ++-- .../beam_search.py | 4 +- icefall/checkpoint.py | 2 +- 11 files changed, 484 insertions(+), 296 deletions(-) diff --git a/.flake8 b/.flake8 index dd9239b2d..5b3c444b8 100644 --- a/.flake8 +++ b/.flake8 @@ -7,6 +7,8 @@ per-file-ignores = egs/librispeech/ASR/*/conformer.py: E501, egs/aishell/ASR/*/conformer.py: E501, egs/tedlium3/ASR/*/conformer.py: E501, + egs/librispeech/ASR/pruned_transducer_stateless2/*.py: E501, + # invalid escape sequence (cause by tex formular), W605 icefall/utils.py: E501, W605 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 5876d5158..fae1d5a96 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -93,7 +93,9 @@ def fast_beam_search( ) # fmt: on logits = model.joiner( - current_encoder_out.unsqueeze(2), decoder_out.unsqueeze(1), project_input=False + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, ) logits = logits.squeeze(1).squeeze(1) log_probs = logits.log_softmax(dim=-1) @@ -140,7 +142,6 @@ def greedy_search( encoder_out = model.joiner.encoder_proj(encoder_out) - T = encoder_out.size(1) t = 0 hyp = [blank_id] * context_size @@ -163,9 +164,9 @@ def greedy_search( # fmt: off current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) # fmt: on - logits = model.joiner(current_encoder_out, - decoder_out.unsqueeze(1), - project_input=False) + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1), project_input=False + ) # logits is (1, 1, 1, vocab_size) y = logits.argmax().item() @@ -228,8 +229,9 @@ def greedy_search_batch( for t in range(T): current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1), - project_input=False) + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1), project_input=False + ) # logits'shape (batch_size, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) @@ -466,7 +468,6 @@ def modified_beam_search( decoder_out = model.joiner.decoder_proj(decoder_out) # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor # as index, so we use `to(torch.int64)` below. current_encoder_out = torch.index_select( @@ -720,7 +721,7 @@ def beam_search( logits = model.joiner( current_encoder_out, decoder_out.unsqueeze(1), - project_input=False + project_input=False, ) # TODO(fangjun): Scale the blank posterior diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 94c6aa90c..257936b59 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -16,13 +16,20 @@ # limitations under the License. import copy -from encoder_interface import EncoderInterface import math import warnings -from typing import Optional, Tuple, Sequence -from scaling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d +from typing import Optional, Tuple import torch +from encoder_interface import EncoderInterface +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledConv1d, + ScaledConv2d, + ScaledLinear, +) from torch import Tensor, nn from icefall.utils import make_pad_mask @@ -42,6 +49,7 @@ class Conformer(EncoderInterface): cnn_module_kernel (int): Kernel size of convolution module vgg_frontend (bool): whether to use vgg frontend. """ + def __init__( self, num_features: int, @@ -80,9 +88,8 @@ class Conformer(EncoderInterface): ) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 + self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -112,8 +119,9 @@ class Conformer(EncoderInterface): assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) - x = self.encoder(x, pos_emb, src_key_padding_mask=mask, - warmup=warmup) # (T, N, C) + x = self.encoder( + x, pos_emb, src_key_padding_mask=mask, warmup=warmup + ) # (T, N, C) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) @@ -176,18 +184,15 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_final = BasicNorm(d_model) # try to ensure the output is close to zero-mean (or at least, zero-median). - self.balancer = ActivationBalancer(channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - max_abs=6.0) + self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) self.dropout = nn.Dropout(dropout) - def forward( self, src: Tensor, @@ -220,14 +225,17 @@ class ConformerEncoderLayer(nn.Module): # alpha = 1.0 means fully use this encoder layer, 0.0 would mean # completely bypass it. if self.training: - alpha = warmup_scale if torch.rand(()).item() <= (1.0 - self.layer_dropout) else 0.1 + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) else: alpha = 1.0 # macaron style feed forward module src = src + self.dropout(self.feed_forward_macaron(src)) - # multi-headed self-attention module src_att = self.self_attn( src, @@ -248,7 +256,7 @@ class ConformerEncoderLayer(nn.Module): src = self.norm_final(self.balancer(src)) if alpha != 1.0: - src = alpha * src + (1-alpha) * src_orig + src = alpha * src + (1 - alpha) * src_orig return src @@ -275,14 +283,13 @@ class ConformerEncoder(nn.Module): ) self.num_layers = num_layers - def forward( self, src: Tensor, pos_emb: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - warmup: float = 1.0 + warmup: float = 1.0, ) -> Tensor: r"""Pass the input through the encoder layers in turn. @@ -302,8 +309,6 @@ class ConformerEncoder(nn.Module): """ output = src - num_layers = len(self.layers) - for i, mod in enumerate(self.layers): output = mod( output, @@ -428,7 +433,9 @@ class RelPositionMultiheadAttention(nn.Module): ), "embed_dim must be divisible by num_heads" self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25) + self.out_proj = ScaledLinear( + embed_dim, embed_dim, bias=True, initial_scale=0.25 + ) # linear transformation for positional encoding. self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) @@ -621,7 +628,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -653,7 +662,6 @@ class RelPositionMultiheadAttention(nn.Module): _b = _b[_start:_end] q = nn.functional.linear(query, _w, _b) - # This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias _start = embed_dim @@ -672,7 +680,6 @@ class RelPositionMultiheadAttention(nn.Module): _b = _b[_start:] v = nn.functional.linear(value, _w, _b) - if attn_mask is not None: assert ( attn_mask.dtype == torch.float32 @@ -864,9 +871,9 @@ class ConvolutionModule(nn.Module): # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, # it will be in a better position to start learning something, i.e. to latch onto # the correct range. - self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0, - min_positive=0.05, - max_positive=1.0) + self.deriv_balancer1 = ActivationBalancer( + channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 + ) self.depthwise_conv = ScaledConv1d( channels, @@ -878,9 +885,9 @@ class ConvolutionModule(nn.Module): bias=bias, ) - self.deriv_balancer2 = ActivationBalancer(channel_dim=1, - min_positive=0.05, - max_positive=1.0) + self.deriv_balancer2 = ActivationBalancer( + channel_dim=1, min_positive=0.05, max_positive=1.0 + ) self.activation = DoubleSwish() @@ -891,7 +898,7 @@ class ConvolutionModule(nn.Module): stride=1, padding=0, bias=bias, - initial_scale=0.25 + initial_scale=0.25, ) def forward(self, x: Tensor) -> Tensor: @@ -924,7 +931,6 @@ class ConvolutionModule(nn.Module): return x.permute(2, 0, 1) - class Conv2dSubsampling(nn.Module): """Convolutional 2D subsampling (to 1/4 length). @@ -936,11 +942,14 @@ class Conv2dSubsampling(nn.Module): https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa """ - def __init__(self, in_channels: int, - out_channels: int, - layer1_channels: int = 8, - layer2_channels: int = 32, - layer3_channels: int = 128) -> None: + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + ) -> None: """ Args: in_channels: @@ -958,34 +967,41 @@ class Conv2dSubsampling(nn.Module): self.conv = nn.Sequential( ScaledConv2d( - in_channels=1, out_channels=layer1_channels, - kernel_size=3, padding=1, + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=1, ), ActivationBalancer(channel_dim=1), DoubleSwish(), ScaledConv2d( - in_channels=layer1_channels, out_channels=layer2_channels, - kernel_size=3, stride=2, + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, ), ActivationBalancer(channel_dim=1), DoubleSwish(), ScaledConv2d( - in_channels=layer2_channels, out_channels=layer3_channels, - kernel_size=3, stride=2, + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=2, ), ActivationBalancer(channel_dim=1), DoubleSwish(), ) - self.out = ScaledLinear(layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels) + self.out = ScaledLinear( + layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels + ) # set learn_eps=False because out_norm is preceded by `out`, and `out` # itself has learned scale, so the extra degree of freedom is not # needed. self.out_norm = BasicNorm(out_channels, learn_eps=False) # constrain median of output to be close to zero. - self.out_balancer = ActivationBalancer(channel_dim=-1, - min_positive=0.45, - max_positive=0.55) - + self.out_balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55 + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -1009,13 +1025,14 @@ class Conv2dSubsampling(nn.Module): return x - -if __name__ == '__main__': +if __name__ == "__main__": feature_dim = 50 c = Conformer(num_features=feature_dim, d_model=128, nhead=4) batch_size = 5 seq_len = 20 # Just make sure the forward pass runs. - f = c(torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - warmup=0.5) + f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + warmup=0.5, + ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index c23568ae9..b6d94aaf1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -17,9 +17,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch import Tensor -from typing import Optional -from scaling import ScaledConv1d, ScaledLinear, ScaledEmbedding +from scaling import ScaledConv1d, ScaledEmbedding class Decoder(nn.Module): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 2299a0a8c..35f75ed2a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -16,15 +16,17 @@ import torch import torch.nn as nn -import torch.nn.functional as F from scaling import ScaledLinear + class Joiner(nn.Module): - def __init__(self, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int): + def __init__( + self, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): super().__init__() self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim) @@ -32,8 +34,10 @@ class Joiner(nn.Module): self.output_linear = ScaledLinear(joiner_dim, vocab_size) def forward( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor, - project_input: bool = True + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + project_input: bool = True, ) -> torch.Tensor: """ Args: @@ -52,7 +56,9 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + logit = self.encoder_proj(encoder_out) + self.decoder_proj( + decoder_out + ) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 81f6df790..599bf2506 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -37,7 +37,7 @@ class Transducer(nn.Module): encoder_dim: int, decoder_dim: int, joiner_dim: int, - vocab_size: int + vocab_size: int, ): """ Args: @@ -48,11 +48,11 @@ class Transducer(nn.Module): `logit_lens` of shape (N,). decoder: It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, decoder_dim). It should contain - one attribute: `blank_id`. + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. joiner: - It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). Its - output shape is (N, T, U, vocab_size). Note that its output contains + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains unnormalized probs, i.e., not processed by log-softmax. """ super().__init__() @@ -63,8 +63,9 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, - initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -141,8 +142,8 @@ class Transducer(nn.Module): boundary[:, 2] = y_lens boundary[:, 3] = x_lens - lm=self.simple_lm_proj(decoder_out) - am=self.simple_am_proj(encoder_out) + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( @@ -170,15 +171,14 @@ class Transducer(nn.Module): am_pruned, lm_pruned = k2.do_rnnt_pruning( am=self.joiner.encoder_proj(encoder_out), lm=self.joiner.decoder_proj(decoder_out), - ranges=ranges + ranges=ranges, ) # logits : [B, T, prune_range, vocab_size] # project_input=False since we applied the decoder's input projections # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, - project_input=False) + logits = self.joiner(am_pruned, lm_pruned, project_input=False) with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index b0d269571..432bf8220 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -15,11 +15,9 @@ # limitations under the License. -import random -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union import torch -from torch import Tensor from torch.optim import Optimizer @@ -59,24 +57,41 @@ class Eve(Optimizer): https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-8, - weight_decay=1e-3, target_rms=0.1): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.98), + eps=1e-8, + weight_decay=1e-3, + target_rms=0.1, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay, - target_rms=target_rms) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + target_rms=target_rms, + ) super(Eve, self).__init__(params, defaults) def __setstate__(self, state): @@ -96,83 +111,98 @@ class Eve(Optimizer): loss = closure() for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError('AdamW does not support sparse gradients') + raise RuntimeError( + "AdamW does not support sparse gradients" + ) state = self.state[p] # State initialization if len(state) == 0: - state['step'] = 0 + state["step"] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - beta1, beta2 = group['betas'] + beta1, beta2 = group["betas"] - state['step'] += 1 - bias_correction1 = 1 - beta1 ** state['step'] - bias_correction2 = 1 - beta2 ** state['step'] + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(group['eps']) + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + group["eps"] + ) - step_size = group['lr'] / bias_correction1 - target_rms = group['target_rms'] - weight_decay = group['weight_decay'] - delta = exp_avg / denom + step_size = group["lr"] / bias_correction1 + target_rms = group["target_rms"] + weight_decay = group["weight_decay"] if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = (p.norm() > (target_rms * (p.numel() ** 0.5))) + is_above_target_rms = p.norm() > ( + target_rms * (p.numel() ** 0.5) + ) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) return loss + class LRScheduler(object): """ Base-class for learning rate schedulers where the learning-rate depends on both the batch and the epoch. """ + def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError('{} is not an Optimizer'.format( - type(optimizer).__name__)) + raise TypeError( + "{} is not an Optimizer".format(type(optimizer).__name__) + ) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: - group.setdefault('initial_lr', group['lr']) + group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups] + self.base_lrs = [ + group["initial_lr"] for group in optimizer.param_groups + ] self.epoch = 0 self.batch = 0 - def state_dict(self): """Returns the state of the scheduler as a :class:`dict`. It contains an entry for every variable in self.__dict__ which is not the optimizer. """ - return {'base_lrs': self.base_lrs, - 'epoch': self.epoch, - 'batch': self.batch} + return { + "base_lrs": self.base_lrs, + "epoch": self.epoch, + "batch": self.batch, + } def load_state_dict(self, state_dict): """Loads the schedulers state. @@ -184,8 +214,7 @@ class LRScheduler(object): self.__dict__.update(state_dict) def get_last_lr(self) -> List[float]: - """ Return last computed learning rate by current scheduler. Will be a list of float. - """ + """Return last computed learning rate by current scheduler. Will be a list of float.""" return self._last_lr def get_lr(self): @@ -194,7 +223,6 @@ class LRScheduler(object): # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] raise NotImplementedError - def step_batch(self, batch: Optional[int] = None) -> None: # Step the batch index, or just set it. If `batch` is specified, it # must be the batch index from the start of training, i.e. summed over @@ -217,24 +245,23 @@ class LRScheduler(object): self.epoch = self.epoch + 1 self._set_lrs() - def _set_lrs(self): values = self.get_lr() assert len(values) == len(self.optimizer.param_groups) for i, data in enumerate(zip(self.optimizer.param_groups, values)): param_group, lr = data - param_group['lr'] = lr + param_group["lr"] = lr self.print_lr(self.verbose, i, lr) - self._last_lr = [group['lr'] for group in self.optimizer.param_groups] - + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] def print_lr(self, is_verbose, group, lr): - """Display the current learning rate. - """ + """Display the current learning rate.""" if is_verbose: - print(f'Epoch={self.epoch}, batch={self.batch}: adjusting learning rate' - f' of group {group} to {lr:.4e}.') + print( + f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" + f" of group {group} to {lr:.4e}." + ) class Eden(LRScheduler): @@ -254,18 +281,27 @@ class Eden(LRScheduler): 20 to 40 epochs, but may need smaller number if dataset is huge and you will do few epochs. """ - def __init__(self, optimizer: Optimizer, - lr_batches: Union[int, float], - lr_epochs: Union[int, float], - verbose: bool = False): + + def __init__( + self, + optimizer: Optimizer, + lr_batches: Union[int, float], + lr_epochs: Union[int, float], + verbose: bool = False, + ): super(Eden, self).__init__(optimizer, verbose) self.lr_batches = lr_batches self.lr_epochs = lr_epochs def get_lr(self): - factor = (((self.batch**2 + self.lr_batches**2) / self.lr_batches**2) ** -0.25 * - (((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25)) - return [ x * factor for x in self.base_lrs ] + factor = ( + (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + ) ** -0.25 * ( + ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) + ** -0.25 + ) + return [x * factor for x in self.base_lrs] + def _test_eden(): m = torch.nn.Linear(100, 100) @@ -290,5 +326,6 @@ def _test_eden(): print("last lr = ", scheduler.get_last_lr()) print("state dict = ", scheduler.state_dict()) -if __name__ == '__main__': + +if __name__ == "__main__": _test_eden() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 98a56ce77..d59aa2160 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -15,54 +15,86 @@ # limitations under the License. +import collections +from itertools import repeat +from typing import Optional, Tuple + import torch import torch.nn as nn from torch import Tensor -from typing import Tuple, Optional +def _ntuple(n): + def parse(x): + if isinstance(x, collections.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +_single = _ntuple(1) +_pair = _ntuple(2) class ActivationBalancerFunction(torch.autograd.Function): @staticmethod - def forward(ctx, x: Tensor, - channel_dim: int, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 + def forward( + ctx, + x: Tensor, + channel_dim: int, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 100.0 ) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] xgt0 = x > 0 - proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) - factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) - if min_positive != 0.0 else 0.0) - factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) - if max_positive != 1.0 else 0.0) + proportion_positive = torch.mean( + xgt0.to(x.dtype), dim=sum_dims, keepdim=True + ) + factor1 = ( + (min_positive - proportion_positive).relu() + * (max_factor / min_positive) + if min_positive != 0.0 + else 0.0 + ) + factor2 = ( + (proportion_positive - max_positive).relu() + * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 + else 0.0 + ) factor = factor1 + factor2 if isinstance(factor, float): factor = torch.zeros_like(proportion_positive) mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) - below_threshold = (mean_abs < min_abs) - above_threshold = (mean_abs > max_abs) + below_threshold = mean_abs < min_abs + above_threshold = mean_abs > max_abs - ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) + ctx.save_for_backward( + factor, xgt0, below_threshold, above_threshold + ) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None, None, None, None]: factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors dtype = x_grad.dtype - scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * - (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) + scale_factor = ( + (below_threshold.to(dtype) - above_threshold.to(dtype)) + * (xgt0.to(dtype) - 0.5) + * (ctx.max_factor * 2.0) + ) neg_delta_grad = x_grad.abs() * (factor + scale_factor) return x_grad - neg_delta_grad, None, None, None, None, None, None @@ -95,29 +127,31 @@ class BasicNorm(torch.nn.Module): learn_eps: if true, we learn epsilon; if false, we keep it at the initial value. """ - def __init__(self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True) -> None: + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True, + ) -> None: super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim if learn_eps: self.eps = nn.Parameter(torch.tensor(eps).log().detach()) else: - self.register_buffer('eps', torch.tensor(eps).log().detach()) - + self.register_buffer("eps", torch.tensor(eps).log().detach()) def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + - self.eps.exp()) ** -0.5 + scales = ( + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + + self.eps.exp() + ) ** -0.5 return x * scales - - class ScaledLinear(nn.Linear): """ A modified version of nn.Linear where the parameters are scaled before @@ -143,19 +177,25 @@ class ScaledLinear(nn.Linear): Alternatively you can set it to more than 1 if you want it to initially train faster. Must be greater than 0. """ - def __init__(self, *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs): + + def __init__( + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs + ): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) + self.register_parameter("bias_scale", None) - self._reset_parameters(initial_speed) # Overrides the reset_parameters in nn.Linear + self._reset_parameters( + initial_speed + ) # Overrides the reset_parameters in nn.Linear def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed @@ -172,28 +212,33 @@ class ScaledLinear(nn.Linear): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), - self.get_bias()) + return torch.nn.functional.linear( + input, self.get_weight(), self.get_bias() + ) class ScaledConv1d(nn.Conv1d): # See docs for ScaledLinear - def __init__(self, *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs): + def __init__( + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs + ): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) - self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class + self.register_parameter("bias_scale", None) + self._reset_parameters( + initial_speed + ) # Overrides the reset_parameters in base class def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed @@ -206,39 +251,58 @@ class ScaledConv1d(nn.Conv1d): with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() - def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - self.get_weight(), self.get_bias(), self.stride, - _single(0), self.dilation, self.groups) - return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) - + if self.padding_mode != "zeros": + return F.conv1d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + self.get_weight(), + self.get_bias(), + self.stride, + _single(0), + self.dilation, + self.groups, + ) + return F.conv1d( + input, + self.get_weight(), + self.get_bias(), + self.stride, + self.padding, + self.dilation, + self.groups, + ) class ScaledConv2d(nn.Conv2d): # See docs for ScaledLinear - def __init__(self, *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs): + def __init__( + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs + ): super(ScaledConv2d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter('bias_scale', None) - self._reset_parameters(initial_speed) # Overrides the reset_parameters in base class + self.register_parameter("bias_scale", None) + self._reset_parameters( + initial_speed + ) # Overrides the reset_parameters in base class def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed @@ -251,29 +315,42 @@ class ScaledConv2d(nn.Conv2d): with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() - def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return (None if self.bias is None else - self.bias * self.bias_scale.exp()) + return None if self.bias is None else self.bias * self.bias_scale.exp() def _conv_forward(self, input, weight): F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - weight, self.get_bias(), self.stride, - _pair(0), self.dilation, self.groups) - return F.conv2d(input, weight, self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) + if self.padding_mode != "zeros": + return F.conv2d( + F.pad( + input, + self._reversed_padding_repeated_twice, + mode=self.padding_mode, + ), + weight, + self.get_bias(), + self.stride, + _pair(0), + self.dilation, + self.groups, + ) + return F.conv2d( + input, + weight, + self.get_bias(), + self.stride, + self.padding, + self.dilation, + self.groups, + ) def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, self.get_weight()) - - class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -302,12 +379,16 @@ class ActivationBalancer(torch.nn.Module): we allow, before we start to modify the derivatives to prevent this. """ - def __init__(self, channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.01, - min_abs: float = 0.2, - max_abs: float = 100.0): + + def __init__( + self, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.01, + min_abs: float = 0.2, + max_abs: float = 100.0, + ): super(ActivationBalancer, self).__init__() self.channel_dim = channel_dim self.min_positive = min_positive @@ -317,10 +398,15 @@ class ActivationBalancer(torch.nn.Module): self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: - return ActivationBalancerFunction.apply(x, self.channel_dim, - self.min_positive, self.max_positive, - self.max_factor, self.min_abs, - self.max_abs) + return ActivationBalancerFunction.apply( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + self.max_factor, + self.min_abs, + self.max_abs, + ) class DoubleSwishFunction(torch.autograd.Function): @@ -338,6 +424,7 @@ class DoubleSwishFunction(torch.autograd.Function): = double_swish(x) * (1-s(x)) + s(x) ... so we just need to remember s(x) but not x itself. """ + @staticmethod def forward(ctx, x: Tensor) -> Tensor: x = x.detach() @@ -349,18 +436,17 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: s, y = ctx.saved_tensors - return (y * (1-s) + s) * y_grad + return (y * (1 - s) + s) * y_grad + class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). + that we approximate closely with x * sigmoid(x-1). """ return DoubleSwishFunction.apply(x) - - class ScaledEmbedding(nn.Module): r"""This is a modified version of nn.Embedding that introduces a learnable scale on the parameters. Note: due to how we initialize it, it's best used with @@ -443,8 +529,13 @@ class ScaledEmbedding(nn.Module): [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', - 'scale_grad_by_freq', 'sparse'] + __constants__ = [ + "num_embeddings", + "embedding_dim", + "padding_idx", + "scale_grad_by_freq", + "sparse", + ] num_embeddings: int embedding_dim: int @@ -453,33 +544,41 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - initial_speed: float = 1.0) -> None: + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + initial_speed: float = 1.0, + ) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters(initial_speed) - def reset_parameters(self, initial_speed: float = 1.0) -> None: std = 0.1 / initial_speed nn.init.normal_(self.weight, std=std) - nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) + nn.init.constant_(self.scale, torch.tensor(1.0 / std).log()) if self.padding_idx is not None: with torch.no_grad(): @@ -489,36 +588,53 @@ class ScaledEmbedding(nn.Module): F = torch.nn.functional scale = self.scale.exp() if input.numel() < self.num_embeddings: - return F.embedding( - input, self.weight, self.padding_idx, - None, 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, self.sparse) * scale + return ( + F.embedding( + input, + self.weight, + self.padding_idx, + None, + 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, + self.sparse, + ) + * scale + ) else: return F.embedding( - input, self.weight * scale, self.padding_idx, - None, 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, self.sparse) + input, + self.weight * scale, + self.padding_idx, + None, + 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, + self.sparse, + ) def extra_repr(self) -> str: - s = '{num_embeddings}, {embedding_dim}, scale={scale}' + s = "{num_embeddings}, {embedding_dim}, scale={scale}" if self.padding_idx is not None: - s += ', padding_idx={padding_idx}' + s += ", padding_idx={padding_idx}" if self.scale_grad_by_freq is not False: - s += ', scale_grad_by_freq={scale_grad_by_freq}' + s += ", scale_grad_by_freq={scale_grad_by_freq}" if self.sparse is not False: - s += ', sparse=True' + s += ", sparse=True" return s.format(**self.__dict__) def _test_activation_balancer_sign(): - channel_dim = 0 probs = torch.arange(0, 1, 0.01) N = 1000 x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = x.detach() x.requires_grad = True - m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, - max_factor=0.2, min_abs=0.0) + m = ActivationBalancer( + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + max_factor=0.2, + min_abs=0.0, + ) y_grad = torch.sign(torch.randn(probs.numel(), N)) @@ -528,17 +644,23 @@ def _test_activation_balancer_sign(): print("_test_activation_balancer_sign: y grad = ", y_grad) print("_test_activation_balancer_sign: x grad = ", x.grad) + def _test_activation_balancer_magnitude(): - channel_dim = 0 magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) x = x.detach() x.requires_grad = True - m = ActivationBalancer(channel_dim=0, - min_positive=0.0, max_positive=1.0, - max_factor=0.2, - min_abs=0.2, max_abs=0.8) + m = ActivationBalancer( + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + max_factor=0.2, + min_abs=0.2, + max_abs=0.8, + ) y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) @@ -558,8 +680,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -573,7 +695,7 @@ def _test_double_swish_deriv(): torch.autograd.gradcheck(m, x) -if __name__ == '__main__': +if __name__ == "__main__": _test_activation_balancer_sign() _test_activation_balancer_magnitude() _test_basic_norm() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index d08fa15b5..80617847a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -45,16 +45,15 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging -import math import warnings from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union import k2 +import optim import sentencepiece as spm import torch -import optim # from . import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule @@ -65,27 +64,24 @@ from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer -from optim import Eve, Eden +from optim import Eden, Eve from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter +from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info -from icefall import diagnostics +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -from icefall.utils import ( - AttributeDict, - MetricsTracker, - setup_logger, - str2bool, -) +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def get_parser(): parser = argparse.ArgumentParser( @@ -168,7 +164,7 @@ def get_parser(): type=float, default=5000, help="""Number of steps that affects how rapidly the learning rate decreases. - We suggest not to change this.""" + We suggest not to change this.""", ) parser.add_argument( @@ -176,7 +172,7 @@ def get_parser(): type=float, default=6, help="""Number of epochs that affects how rapidly the learning rate decreases. - """ + """, ) parser.add_argument( @@ -335,7 +331,7 @@ def get_params() -> AttributeDict: # parameters for joiner "joiner_dim": 512, # parameters for Noam - "model_warm_step": 3000, # arg given to model, not for lrate + "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), } ) @@ -510,7 +506,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, - warmup: float = 1.0 + warmup: float = 1.0, ) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. @@ -557,18 +553,24 @@ def compute_loss( # for the same amount of time (model_warm_step), to avoid # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. - pruned_loss_scale = (0.0 if warmup < 1.0 else - (0.1 if warmup > 1.0 and warmup < 2.0 else - 1.0)) - loss = (params.simple_loss_scale * simple_loss + - pruned_loss_scale * pruned_loss) + pruned_loss_scale = ( + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss + ) assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -675,7 +677,7 @@ def train_one_epoch( sp=sp, batch=batch, is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step) + warmup=(params.batch_idx_train / params.model_warm_step), ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info @@ -691,8 +693,10 @@ def train_one_epoch( if params.print_diagnostics and batch_idx == 5: return - if (params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0): + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): params.cur_batch_idx = batch_idx save_checkpoint_with_global_batch_idx( out_dir=params.exp_dir, @@ -723,7 +727,7 @@ def train_one_epoch( if tb_writer is not None: tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train + "train/learning_rate", cur_lr, params.batch_idx_train ) loss_info.write_summary( @@ -813,18 +817,19 @@ def run(rank, world_size, args): model = DDP(model, device_ids=[rank]) model.device = device - optimizer = Eve( - model.parameters(), - lr=params.initial_lr) + optimizer = Eve(model.parameters(), lr=params.initial_lr) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") optimizer.load_state_dict(checkpoints["optimizer"]) - if checkpoints and "scheduler" in checkpoints and checkpoints["scheduler"] is not None: + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): logging.info("Loading scheduler state dict") scheduler.load_state_dict(checkpoints["scheduler"]) @@ -834,7 +839,6 @@ def run(rank, world_size, args): ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) - librispeech = LibriSpeechAsrDataModule(args) train_cuts = librispeech.train_clean_100_cuts() @@ -889,7 +893,6 @@ def run(rank, world_size, args): fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) - cur_lr = scheduler.get_last_lr()[0] if tb_writer is not None: tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) @@ -956,7 +959,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, - warmup = 0.0 + warmup=0.0, ) loss.backward() optimizer.step() diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py b/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py index 0ae001d3f..3a08b100d 100644 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py @@ -486,7 +486,9 @@ def modified_beam_search( for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - topk_hyp_indexes = torch.div(topk_indexes, vocab_size, rounding_mode="trunc") + topk_hyp_indexes = torch.div( + topk_indexes, vocab_size, rounding_mode="trunc" + ) topk_hyp_indexes = topk_hyp_indexes.tolist() topk_token_indexes = (topk_indexes % vocab_size).tolist() diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 4dbabe7dc..cc167292b 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -29,11 +29,11 @@ from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer - # use duck typing for LRScheduler since we have different possibilities, see # our class LRScheduler. LRSchedulerType = object + def save_checkpoint( filename: Path, model: Union[nn.Module, DDP], From 118e195004ef41f07a26018fff87ee79acea9d31 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Mon, 11 Apr 2022 22:19:26 +0800 Subject: [PATCH 210/234] Update results for tedlium3 pruned RNN-T (#307) * Update README.md --- README.md | 2 +- egs/tedlium3/ASR/RESULTS.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 5f7f7c7dd..6adba4955 100644 --- a/README.md +++ b/README.md @@ -176,7 +176,7 @@ The best WER using modified beam search with beam size 4 is: | | dev | test | |-----|-------|--------| -| WER | 6.72 | 6.12 | +| WER | 6.77 | 6.14 | We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1je_1zGrOkGVVd4WLzgkXRHxl-I27yWtz?usp=sharing) diff --git a/egs/tedlium3/ASR/RESULTS.md b/egs/tedlium3/ASR/RESULTS.md index beeeb047b..511b19f73 100644 --- a/egs/tedlium3/ASR/RESULTS.md +++ b/egs/tedlium3/ASR/RESULTS.md @@ -12,7 +12,7 @@ The WERs are |------------------------------------|------------|------------|------------------------------------------| | greedy search | 7.27 | 6.69 | --epoch 29, --avg 13, --max-duration 100 | | beam search (beam size 4) | 6.70 | 6.04 | --epoch 29, --avg 13, --max-duration 100 | -| modified beam search (beam size 4) | 6.77 | 6.12 | --epoch 29, --avg 13, --max-duration 100 | +| modified beam search (beam size 4) | 6.77 | 6.14 | --epoch 29, --avg 13, --max-duration 100 | | fast beam search (set as default) | 7.14 | 6.50 | --epoch 29, --avg 13, --max-duration 1500| The training command for reproducing is given below: From bdeff338c2245b3980d0385ffa37e4468aa6a02e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 12 Apr 2022 09:09:56 +0800 Subject: [PATCH 211/234] Fix CI errors. (#310) --- egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index b0eb4d749..3cc472974 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -174,6 +174,7 @@ def main(): # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() logging.info(f"{params}") From 65818d16ded697d6b11c65addc002ac5faae2eaf Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 12 Apr 2022 11:48:16 +0800 Subject: [PATCH 212/234] Add more results --- egs/librispeech/ASR/RESULTS.md | 71 +++++++++++++++++++++++----------- 1 file changed, 48 insertions(+), 23 deletions(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index ce90da356..645e24fdc 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -13,9 +13,15 @@ The WERs are: | | test-clean | test-other | comment | |-------------------------------------|------------|------------|-------------------------------------------------------------------------------| -| greedy search (max sym per frame 1) | 2.62 | 6.37 | --epoch 25, --avg 8, --max-duration 600 | -| fast beam search | 2.61 | 6.17 | --epoch 25, --avg 8, --max-duration 600 --decoding-method fast_beam_search | -| modified beam search | 2.59 | 6.19 | --epoch 25, --avg 8, --max-duration 600 --decoding-method modified_beam_search| +| greedy search (max sym per frame 1) | 2.62 | 6.37 | --epoch 25 --avg 8 --max-duration 600 | +| fast beam search | 2.61 | 6.17 | --epoch 25 --avg 8 --max-duration 600 --decoding-method fast_beam_search | +| modified beam search | 2.59 | 6.19 | --epoch 25 --avg 8 --max-duration 600 --decoding-method modified_beam_search| +| greedy search (max sym per frame 1) | 2.70 | 6.04 | --epoch 34 --avg 10 --max-duration 600 | +| fast beam search | 2.66 | 6.00 | --epoch 34 --avg 10 --max-duration 600 --decoding-method fast_beam_search | +| greedy search (max sym per frame 1) | 2.60 | 6.06 | --epoch 37 --avg 10 --max-duration 600 | +| fast beam search | 2.62 | 5.97 | --epoch 37 --avg 10 --max-duration 600 --decoding-method fast_beam_search | + + The train and decode commands are: @@ -23,7 +29,8 @@ The train and decode commands are: and: `python3 ./pruned_transducer_stateless2/decode.py --exp-dir pruned_transducer_stateless2/exp --epoch 25 --avg 8 --bpe-model ./data/lang_bpe_500/bpe.model --max-duration 600` -The Tensorboard log is at +The Tensorboard log is at (apologies, log starts +only from epoch 3). The WERs for librispeech 100 hours are: @@ -40,7 +47,9 @@ schedule is not visible due to a since-fixed bug). |-------------------------------------|------------|------------|-------------------------------------------------------| | greedy search (max sym per frame 1) | 7.12 | 18.42 | --epoch 19 --avg 8 | | greedy search (max sym per frame 1) | 6.71 | 17.77 | --epoch 29 --avg 8 | -| fast beam search | 6.58 | 17.27 | --epoch 19 --avg 8 --decoding-method fast_beam_search | +| greedy search (max sym per frame 1) | 6.64 | 17.19 | --epoch 39 --avg 10 | +| fast beam search | 6.58 | 17.27 | --epoch 29 --avg 8 --decoding-method fast_beam_search | +| fast beam search | 6.53 | 16.82 | --epoch 39 --avg 10 --decoding-method fast_beam_search | Trained with two jobs: `python3 ./pruned_transducer_stateless2/train.py --exp-dir=pruned_transducer_stateless2/exp_100h_ws2 --world-size 2 --num-epochs 40 --full-libri 0 --max-duration 300` @@ -52,9 +61,9 @@ The Tensorboard log is at . +Train command was +`python3 ./pruned_transducer_stateless2/train.py --exp-dir=pruned_transducer_stateless2/exp_100h_fp16 --world-size 1 --num-epochs 40 --full-libri 0 --max-duration 500 --use-fp16 True` + +The Tensorboard log is at + +| | test-clean | test-other | comment | +|-------------------------------------|------------|------------|-----------------------| +| greedy search (max sym per frame 1) | 7.10 | 18.79 | --epoch 19 --avg 8 | +| greedy search (max sym per frame 1) | 6.92 | 18.16 | --epoch 29 --avg 8 | +| greedy search (max sym per frame 1) | 6.89 | 17.75 | --epoch 30 --avg 10 | + +https://tensorboard.dev/experiment/Km7QBHYnSLWs4qQnAJWsaA/ @@ -91,11 +116,11 @@ The WERs are: | | test-clean | test-other | comment | |-------------------------------------|------------|------------|------------------------------------------| -| greedy search (max sym per frame 1) | 2.62 | 6.37 | --epoch 42, --avg 11, --max-duration 100 | -| greedy search (max sym per frame 2) | 2.62 | 6.37 | --epoch 42, --avg 11, --max-duration 100 | -| greedy search (max sym per frame 3) | 2.62 | 6.37 | --epoch 42, --avg 11, --max-duration 100 | -| modified beam search (beam size 4) | 2.56 | 6.27 | --epoch 42, --avg 11, --max-duration 100 | -| beam search (beam size 4) | 2.57 | 6.27 | --epoch 42, --avg 11, --max-duration 100 | +| greedy search (max sym per frame 1) | 2.62 | 6.37 | --epoch 42 --avg 11 --max-duration 100 | +| greedy search (max sym per frame 2) | 2.62 | 6.37 | --epoch 42 --avg 11 --max-duration 100 | +| greedy search (max sym per frame 3) | 2.62 | 6.37 | --epoch 42 --avg 11 --max-duration 100 | +| modified beam search (beam size 4) | 2.56 | 6.27 | --epoch 42 --avg 11 --max-duration 100 | +| beam search (beam size 4) | 2.57 | 6.27 | --epoch 42 --avg 11 --max-duration 100 | @@ -189,7 +214,7 @@ The WERs are | | test-clean | test-other | comment | |---------------------------|------------|------------|------------------------------------------| -| greedy search | 2.85 | 6.98 | --epoch 28, --avg 15, --max-duration 100 | +| greedy search | 2.85 | 6.98 | --epoch 28 --avg 15 --max-duration 100 | The training command for reproducing is given below: @@ -249,8 +274,8 @@ The WERs are | | test-clean | test-other | comment | |-------------------------------------|------------|------------|------------------------------------------| -| greedy search (max sym per frame 1) | 2.64 | 6.55 | --epoch 39, --avg 15, --max-duration 100 | -| modified beam search (beam size 4) | 2.61 | 6.46 | --epoch 39, --avg 15, --max-duration 100 | +| greedy search (max sym per frame 1) | 2.64 | 6.55 | --epoch 39 --avg 15 --max-duration 100 | +| modified beam search (beam size 4) | 2.61 | 6.46 | --epoch 39 --avg 15 --max-duration 100 | The training command for reproducing is given below: @@ -319,10 +344,10 @@ The WERs are | | test-clean | test-other | comment | |-------------------------------------|------------|------------|------------------------------------------| -| greedy search (max sym per frame 1) | 2.67 | 6.67 | --epoch 63, --avg 19, --max-duration 100 | -| greedy search (max sym per frame 2) | 2.67 | 6.67 | --epoch 63, --avg 19, --max-duration 100 | -| greedy search (max sym per frame 3) | 2.67 | 6.67 | --epoch 63, --avg 19, --max-duration 100 | -| modified beam search (beam size 4) | 2.67 | 6.57 | --epoch 63, --avg 19, --max-duration 100 | +| greedy search (max sym per frame 1) | 2.67 | 6.67 | --epoch 63 --avg 19 --max-duration 100 | +| greedy search (max sym per frame 2) | 2.67 | 6.67 | --epoch 63 --avg 19 --max-duration 100 | +| greedy search (max sym per frame 3) | 2.67 | 6.67 | --epoch 63 --avg 19 --max-duration 100 | +| modified beam search (beam size 4) | 2.67 | 6.57 | --epoch 63 --avg 19 --max-duration 100 | The training command for reproducing is given below: From d0a53aad487ff24dc1fca256346cc3350239cfff Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 12 Apr 2022 11:51:15 +0800 Subject: [PATCH 213/234] Fix tensorboard log location --- egs/librispeech/ASR/RESULTS.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 645e24fdc..9f47ac495 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -87,7 +87,7 @@ floats and max-duration increased from 300 to 500, after merging +The Tensorboard log is at | | test-clean | test-other | comment | |-------------------------------------|------------|------------|-----------------------| @@ -95,7 +95,6 @@ The Tensorboard log is at Date: Tue, 12 Apr 2022 12:20:10 +0800 Subject: [PATCH 214/234] Add one more epoch of full expt --- egs/librispeech/ASR/RESULTS.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 9f47ac495..01637beb1 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -14,12 +14,12 @@ The WERs are: | | test-clean | test-other | comment | |-------------------------------------|------------|------------|-------------------------------------------------------------------------------| | greedy search (max sym per frame 1) | 2.62 | 6.37 | --epoch 25 --avg 8 --max-duration 600 | -| fast beam search | 2.61 | 6.17 | --epoch 25 --avg 8 --max-duration 600 --decoding-method fast_beam_search | -| modified beam search | 2.59 | 6.19 | --epoch 25 --avg 8 --max-duration 600 --decoding-method modified_beam_search| +| fast beam search | 2.61 | 6.17 | --epoch 25 --avg 8 --max-duration 600 --decoding-method fast_beam_search | +| modified beam search | 2.59 | 6.19 | --epoch 25 --avg 8 --max-duration 600 --decoding-method modified_beam_search | | greedy search (max sym per frame 1) | 2.70 | 6.04 | --epoch 34 --avg 10 --max-duration 600 | -| fast beam search | 2.66 | 6.00 | --epoch 34 --avg 10 --max-duration 600 --decoding-method fast_beam_search | -| greedy search (max sym per frame 1) | 2.60 | 6.06 | --epoch 37 --avg 10 --max-duration 600 | -| fast beam search | 2.62 | 5.97 | --epoch 37 --avg 10 --max-duration 600 --decoding-method fast_beam_search | +| fast beam search | 2.66 | 6.00 | --epoch 34 --avg 10 --max-duration 600 --decoding-method fast_beam_search | +| greedy search (max sym per frame 1) | 2.62 | 6.03 | --epoch 38 --avg 10 --max-duration 600 | +| fast beam search | 2.57 | 5.95 | --epoch 38 --avg 10 --max-duration 600 --decoding-method fast_beam_search | From c2808f8541371d55e1cbd9d3129d0851c7295d92 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Tue, 12 Apr 2022 20:13:51 +0800 Subject: [PATCH 215/234] Support cache of left context for causal convolution. --- .../ASR/conv_emformer_transducer/emformer.py | 153 +++++++++++++----- .../conv_emformer_transducer/test_emformer.py | 32 ++-- .../ASR/conv_emformer_transducer/train.py | 2 +- 3 files changed, 134 insertions(+), 53 deletions(-) diff --git a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py index 5ac65141e..e9ce56aa7 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py @@ -601,24 +601,8 @@ class EmformerLayer(nn.Module): ) return right_context_utterance - def _apply_conv_module( - self, - right_context_utterance: torch.Tensor, - right_context_end_idx: int, - ) -> torch.Tensor: - """Apply convolution module on utterance.""" - utterance = right_context_utterance[right_context_end_idx:] - right_context = right_context_utterance[:right_context_end_idx] - - residual = utterance - utterance = self.norm_conv(utterance) - utterance = residual + self.dropout(self.conv_module(utterance)) - right_context_utterance = torch.cat([right_context, utterance]) - return right_context_utterance - def _apply_feed_forward_module( - self, - right_context_utterance: torch.Tensor, + self, right_context_utterance: torch.Tensor ) -> torch.Tensor: """Apply feed forward module.""" residual = right_context_utterance @@ -628,6 +612,39 @@ class EmformerLayer(nn.Module): ) return right_context_utterance + def _apply_conv_module_forward( + self, + right_context_utterance: torch.Tensor, + right_context_end_idx: int, + ) -> torch.Tensor: + """Apply convolution module on utterance in non-infer mode.""" + utterance = right_context_utterance[right_context_end_idx:] + right_context = right_context_utterance[:right_context_end_idx] + + residual = utterance + utterance = self.norm_conv(utterance) + utterance, _ = self.conv_module(utterance) + utterance = residual + self.dropout(utterance) + right_context_utterance = torch.cat([right_context, utterance]) + return right_context_utterance + + def _apply_conv_module_infer( + self, + right_context_utterance: torch.Tensor, + right_context_end_idx: int, + conv_cache: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Apply convolution module on utterance in infer mode.""" + utterance = right_context_utterance[right_context_end_idx:] + right_context = right_context_utterance[:right_context_end_idx] + + residual = utterance + utterance = self.norm_conv(utterance) + utterance, conv_cache = self.conv_module(utterance, conv_cache) + utterance = residual + self.dropout(utterance) + right_context_utterance = torch.cat([right_context, utterance]) + return right_context_utterance, conv_cache + def _apply_attention_module_forward( self, right_context_utterance: torch.Tensor, @@ -790,7 +807,7 @@ class EmformerLayer(nn.Module): attention_mask, ) - right_context_utterance = self._apply_conv_module( + right_context_utterance = self._apply_conv_module_forward( right_context_utterance, right_context_end_idx ) @@ -812,6 +829,7 @@ class EmformerLayer(nn.Module): right_context: torch.Tensor, memory: torch.Tensor, state: Optional[List[torch.Tensor]] = None, + conv_cache: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]: """Forward pass for inference. @@ -841,6 +859,8 @@ class EmformerLayer(nn.Module): state (List[torch.Tensor], optional): List of tensors representing layer internal state generated in preceding computation. (default=None) + conv_cache (torch.Tensor, optional): + Cache tensor of left context for causal convolution. Returns: (Tensor, Tensor, List[torch.Tensor], Tensor): @@ -848,6 +868,7 @@ class EmformerLayer(nn.Module): - output right_context, with shape (R, B, D); - output memory, with shape (1, B, D) or (0, B, D). - output state. + - updated conv_cache. """ right_context_utterance = torch.cat([right_context, utterance]) right_context_end_idx = right_context.size(0) @@ -868,8 +889,10 @@ class EmformerLayer(nn.Module): state, ) - right_context_utterance = self._apply_conv_module( - right_context_utterance, right_context_end_idx + right_context_utterance, conv_cache = self._apply_conv_module_infer( + right_context_utterance, + right_context_end_idx, + conv_cache, ) right_context_utterance = self._apply_feed_forward_module( @@ -885,6 +908,7 @@ class EmformerLayer(nn.Module): output_right_context, output_memory, output_state, + conv_cache, ) @@ -1156,7 +1180,10 @@ class EmformerEncoder(nn.Module): x: torch.Tensor, lengths: torch.Tensor, states: Optional[List[List[torch.Tensor]]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]: + conv_caches: Optional[List[torch.Tensor]] = None, + ) -> Tuple[ + torch.Tensor, torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor] + ]: """Forward pass for streaming inference. B: batch size; @@ -1173,15 +1200,18 @@ class EmformerEncoder(nn.Module): right_context at the end. states (List[List[torch.Tensor]], optional): Cached states from proceeding chunk's computation, where each - element (List[torch.Tensor]) corresponding to each emformer layer. + element (List[torch.Tensor]) corresponds to each emformer layer. (default: None) - + conv_caches (List[torch.Tensor], optional): + Cached tensors of left context for causal convolution, where each + element (Tensor) corresponds to each convolutional layer. Returns: - (Tensor, Tensor, List[List[torch.Tensor]]): + (Tensor, Tensor, List[List[torch.Tensor]], List[torch.Tensor]): - output utterance frames, with shape (U, B, D). - output lengths, with shape (B,), without containing the right_context at the end. - updated states from current chunk's computation. + - updated convolution caches from current chunk. """ assert x.size(0) == self.chunk_length + self.right_context_length, ( "Per configured chunk_length and right_context_length, " @@ -1199,17 +1229,26 @@ class EmformerEncoder(nn.Module): ) output = utterance output_states: List[List[torch.Tensor]] = [] + output_conv_caches: List[torch.Tensor] = [] for layer_idx, layer in enumerate(self.emformer_layers): - output, right_context, memory, output_state = layer.infer( + ( + output, + right_context, + memory, + output_state, + output_conv_cache, + ) = layer.infer( output, output_lengths, right_context, memory, None if states is None else states[layer_idx], + None if conv_caches is None else conv_caches[layer_idx], ) output_states.append(output_state) + output_conv_caches.append(output_conv_cache) - return output, output_lengths, output_states + return output, output_lengths, output_states, output_conv_caches class Emformer(EncoderInterface): @@ -1328,6 +1367,7 @@ class Emformer(EncoderInterface): x: torch.Tensor, x_lens: torch.Tensor, states: Optional[List[List[torch.Tensor]]] = None, + conv_caches: Optional[List[torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]: """Forward pass for streaming inference. @@ -1345,8 +1385,11 @@ class Emformer(EncoderInterface): right_context at the end. states (List[List[torch.Tensor]], optional): Cached states from proceeding chunk's computation, where each - element (List[torch.Tensor]) corresponding to each emformer layer. + element (List[torch.Tensor]) corresponds to each emformer layer. (default: None) + conv_caches (List[torch.Tensor], optional): + Cached tensors of left context for causal convolution, where each + element (Tensor) corresponds to each convolutional layer. Returns: (Tensor, Tensor): - output logits, with shape (B, T', D), where @@ -1354,6 +1397,7 @@ class Emformer(EncoderInterface): - logits lengths, with shape (B,), without containing the right_context at the end. - updated states from current chunk's computation. + - updated convolution caches from current chunk. """ x = self.encoder_embed(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) @@ -1364,14 +1408,17 @@ class Emformer(EncoderInterface): x_lens = ((x_lens - 1) // 2 - 1) // 2 assert x.size(0) == x_lens.max().item() - output, output_lengths, output_states = self.encoder.infer( - x, x_lens, states - ) # (T, N, C) + ( + output, + output_lengths, + output_states, + output_conv_caches, + ) = self.encoder.infer(x, x_lens, states, conv_caches) logits = self.encoder_output_layer(output) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - return logits, output_lengths, output_states + return logits, output_lengths, output_states, output_conv_caches class ConvolutionModule(nn.Module): @@ -1437,28 +1484,50 @@ class ConvolutionModule(nn.Module): ) self.activation = Swish() - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + cache: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """Compute convolution module. Args: - x: Input tensor (#time, batch, channels). - + x (torch.Tensor): + Input tensor (#time, batch, channels). + cache (torch.Tensor, optional): + Cached tensor for left padding (#batch, channels, cache_time). Returns: - Tensor: Output tensor (#time, batch, channels). - + A tuple of 2 tensors: + - output tensor (#time, batch, channels). + - updated cache tensor (#batch, channels, cache_time). """ # exchange the temporal dimension and the feature dimension x = x.permute(1, 2, 0) # (#batch, channels, time). - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - # 1D Depthwise Conv if self.left_padding > 0: # manualy padding self.lorder zeros to the left # make depthwise_conv causal - x = nn.functional.pad(x, (self.left_padding, 0), "constant", 0.0) + if cache is None: + x = nn.functional.pad( + x, (self.left_padding, 0), "constant", 0.0 + ) + else: + assert cache.size(0) == x.size(0) # equal batch + assert cache.size(1) == x.size(1) # equal channel + assert cache.size(2) == self.left_padding + x = torch.cat([cache, x], dim=2) + new_cache = x[:, :, x.size(2) - self.left_padding :] # noqa + else: + # It's better we just return None if no cache is requried, + # However, for JIT export, here we just fake one tensor instead of + # None. + new_cache = None + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + x = self.depthwise_conv(x) # x is (batch, channels, time) x = x.permute(0, 2, 1) @@ -1469,7 +1538,7 @@ class ConvolutionModule(nn.Module): x = self.pointwise_conv2(x) # (batch, channel, time) - return x.permute(2, 0, 1) + return x.permute(2, 0, 1), new_cache class Swish(torch.nn.Module): diff --git a/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py b/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py index 7685bfb26..41e911e17 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py @@ -133,6 +133,7 @@ def test_emformer_layer_infer(): R, L = 2, 5 chunk_length = 2 U = chunk_length + K = 3 for use_memory in [True, False]: if use_memory: @@ -145,7 +146,7 @@ def test_emformer_layer_infer(): nhead=8, dim_feedforward=1024, chunk_length=chunk_length, - cnn_module_kernel=3, + cnn_module_kernel=K, left_context_length=L, max_memory_size=M, causal=True, @@ -157,17 +158,15 @@ def test_emformer_layer_infer(): right_context = torch.randn(R, B, D) memory = torch.randn(M, B, D) state = None + conv_cache = None ( output_utterance, output_right_context, output_memory, output_state, + output_conv_cache, ) = layer.infer( - utterance, - lengths, - right_context, - memory, - state, + utterance, lengths, right_context, memory, state, conv_cache ) assert output_utterance.shape == (U, B, D) assert output_right_context.shape == (R, B, D) @@ -180,6 +179,7 @@ def test_emformer_layer_infer(): assert output_state[1].shape == (L, B, D) assert output_state[2].shape == (L, B, D) assert output_state[3].shape == (1, B) + assert output_conv_cache.shape == (B, D, K - 1) def test_emformer_encoder_forward(): @@ -226,6 +226,7 @@ def test_emformer_encoder_infer(): U = chunk_length num_chunks = 3 num_encoder_layers = 2 + K = 3 for use_memory in [True, False]: if use_memory: @@ -238,7 +239,7 @@ def test_emformer_encoder_infer(): d_model=D, dim_feedforward=1024, num_encoder_layers=num_encoder_layers, - cnn_module_kernel=3, + cnn_module_kernel=K, left_context_length=L, right_context_length=R, max_memory_size=M, @@ -246,11 +247,14 @@ def test_emformer_encoder_infer(): ) states = None + conv_caches = None for chunk_idx in range(num_chunks): x = torch.randn(U + R, B, D) lengths = torch.randint(1, U + R + 1, (B,)) lengths[0] = U + R - output, output_lengths, states = encoder.infer(x, lengths, states) + output, output_lengths, states, conv_caches = encoder.infer( + x, lengths, states, conv_caches + ) assert output.shape == (U, B, D) assert torch.equal(output_lengths, torch.clamp(lengths - R, min=0)) assert len(states) == num_encoder_layers @@ -262,6 +266,8 @@ def test_emformer_encoder_infer(): assert torch.equal( state[3], (chunk_idx + 1) * U * torch.ones_like(state[3]) ) + for conv_cache in conv_caches: + assert conv_cache.shape == (B, D, K - 1) def test_emformer_forward(): @@ -312,6 +318,7 @@ def test_emformer_infer(): B, D = 2, 256 num_chunks = 3 num_encoder_layers = 2 + K = 3 for use_memory in [True, False]: if use_memory: M = 3 @@ -324,7 +331,7 @@ def test_emformer_infer(): subsampling_factor=4, d_model=D, num_encoder_layers=num_encoder_layers, - cnn_module_kernel=3, + cnn_module_kernel=K, left_context_length=L, right_context_length=R, max_memory_size=M, @@ -332,11 +339,14 @@ def test_emformer_infer(): causal=True, ) states = None + conv_caches = None for chunk_idx in range(num_chunks): x = torch.randn(B, U + R + 3, num_features) x_lens = torch.randint(1, U + R + 3 + 1, (B,)) x_lens[0] = U + R + 3 - logits, output_lengths, states = model.infer(x, x_lens, states) + logits, output_lengths, states, conv_caches = model.infer( + x, x_lens, states, conv_caches + ) assert logits.shape == (B, U // 4, output_dim) assert torch.equal( output_lengths, @@ -352,6 +362,8 @@ def test_emformer_infer(): state[3], U // 4 * (chunk_idx + 1) * torch.ones_like(state[3]), ) + for conv_cache in conv_caches: + assert conv_cache.shape == (B, D, K - 1) if __name__ == "__main__": diff --git a/egs/librispeech/ASR/conv_emformer_transducer/train.py b/egs/librispeech/ASR/conv_emformer_transducer/train.py index d0126bb94..8a0eecc6b 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer/train.py @@ -139,7 +139,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--causal-conv", - type=bool, + type=str2bool, default=True, help="Whether use causal convolution.", ) From 78418ac37cdcfa1c4d0f54fe77901f74644ff96a Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Wed, 13 Apr 2022 13:09:24 +0800 Subject: [PATCH 216/234] fix comments --- egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py | 4 ++-- icefall/diagnostics.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index d59aa2160..f89d2963e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -367,7 +367,7 @@ class ActivationBalancer(torch.nn.Module): min_positive: the minimum, per channel, of the proportion of the time that (x > 0), below which we start to modify the derivatives. max_positive: the maximum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. + that (x > 0), above which we start to modify the derivatives. max_factor: the maximum factor by which we modify the derivatives for either the sign constraint or the magnitude constraint; e.g. with max_factor=0.02, the the derivatives would be multiplied by @@ -413,7 +413,7 @@ class DoubleSwishFunction(torch.autograd.Function): """ double_swish(x) = x * torch.sigmoid(x-1) This is a definition, originally motivated by its close numerical - similarity to swish(swish(x), where swish(x) = x * sigmoid(x). + similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). Memory-efficient derivative computation: double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index ce4ac1464..bc8fe3069 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -111,7 +111,7 @@ def get_diagnostics_for_dim( options object sizes_same: True if all the tensor sizes are the same on this dimension - stats_type: either "abs" or "positive" or "eigs" or "value", + stats_type: either "abs" or "positive" or "eigs" or "value", imdictates the type of stats we accumulate, abs is mean absolute value, "positive" is proportion of positive to nonnegative values, "eigs" is eigenvalues after doing outer product on this dim, sum From af6ae840ee4a30f67eb3dafb30fd2aeb51e41768 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 13 Apr 2022 20:22:11 +0800 Subject: [PATCH 217/234] Add results for mixed precision with max-duration 300 --- egs/librispeech/ASR/RESULTS.md | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 01637beb1..3488535a6 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -2,9 +2,10 @@ ### LibriSpeech BPE training results (Pruned Transducer 2) +[pruned_transducer_stateless2](./pruned_transducer_stateless2) This is with a reworked version of the conformer encoder, with many changes. -[pruned_transducer_stateless2](./pruned_transducer_stateless2) +#### Training on fulll librispeech using commit `34aad74a2c849542dd5f6359c9e6b527e8782fd6`. See @@ -33,9 +34,9 @@ The Tensorboard log is at . +Train command was +`python3 ./pruned_transducer_stateless2/train.py --exp-dir=pruned_transducer_stateless2/exp_100h_fp16 --world-size 1 --num-epochs 40 --full-libri 0 --max-duration 300 --use-fp16 True` + +The Tensorboard log is at + +| | test-clean | test-other | comment | +|-------------------------------------|------------|------------|-----------------------| +| greedy search (max sym per frame 1) | 7.10 | 18.57 | --epoch 19 --avg 8 | +| greedy search (max sym per frame 1) | 6.81 | 17.84 | --epoch 29 --avg 8 | +| greedy search (max sym per frame 1) | 6.63 | 17.39 | --epoch 30 --avg 10 | + + Trained with 1 job, with --use-fp16=True --max-duration=500, i.e. with half-precision floats and max-duration increased from 300 to 500, after merging . Train command was From 4130892971d32abc58c0bb195eaf667713864e9c Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Wed, 13 Apr 2022 23:46:42 +0800 Subject: [PATCH 218/234] delete duplicated dropout in emformer attention and update emformer test codes. --- .../emformer.py | 13 +- .../test_emformer.py | 222 +++++++++++++++++- 2 files changed, 220 insertions(+), 15 deletions(-) diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py index 4ba19ebae..67e9f5891 100644 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py @@ -183,9 +183,9 @@ class EmformerAttention(nn.Module): attention_probs = nn.functional.softmax( attention_weights_float, dim=-1 ).type_as(attention_weights) - attention_probs = nn.functional.dropout( - attention_probs, p=float(self.dropout), training=self.training - ) + # attention_probs = nn.functional.dropout( + # attention_probs, p=float(self.dropout), training=self.training + # ) return attention_probs def _forward_impl( @@ -955,16 +955,15 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_segs = math.ceil( + num_chunks = math.ceil( (T - self.right_context_length) / self.chunk_length ) right_context_blocks = [] - for seg_idx in range(num_segs - 1): + for seg_idx in range(num_chunks - 1): start = (seg_idx + 1) * self.chunk_length end = start + self.right_context_length right_context_blocks.append(x[start:end]) - last_right_context_start_idx = T - self.right_context_length - right_context_blocks.append(x[last_right_context_start_idx:]) + right_context_blocks.append(x[T - self.right_context_length :]) # noqa return torch.cat(right_context_blocks) def _gen_attention_mask_col_widths( diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py index 56cf2035e..5e08640d3 100644 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py @@ -342,12 +342,218 @@ def test_emformer_infer(): ) +def test_emformer_attention_forward_infer_consistency(): + from emformer import EmformerEncoder + + chunk_length = 4 + num_chunks = 3 + U = chunk_length * num_chunks + L, R = 1, 2 + D = 256 + num_encoder_layers = 1 + memory_sizes = [0, 3] + + for M in memory_sizes: + encoder = EmformerEncoder( + chunk_length=chunk_length, + d_model=D, + dim_feedforward=1024, + num_encoder_layers=num_encoder_layers, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + dropout=0.0, + ) + encoder_layer = encoder.emformer_layers[0] + + x = torch.randn(U + R, 1, D) + lengths = torch.tensor([U]) + right_context = encoder._gen_right_context(x) + utterance = x[: x.size(0) - R] + attention_mask = encoder._gen_attention_mask(utterance) + memory = ( + encoder.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1 + ] + if encoder.use_memory + else torch.empty(0).to(dtype=x.dtype, device=x.device) + ) + ( + forward_output_right_context_utterance, + forward_output_memory, + ) = encoder_layer._apply_attention_forward( + utterance, + lengths, + right_context, + memory, + attention_mask, + ) + forward_output_utterance = forward_output_right_context_utterance[ + right_context.size(0) : # noqa + ] + + state = None + for chunk_idx in range(num_chunks): + start_idx = chunk_idx * chunk_length + end_idx = start_idx + chunk_length + chunk = x[start_idx:end_idx] + chunk_right_context = x[end_idx : end_idx + R] # noqa + chunk_length = torch.tensor([chunk_length]) + chunk_memory = ( + encoder.init_memory_op(chunk.permute(1, 2, 0)).permute(2, 0, 1) + if encoder.use_memory + else torch.empty(0).to(dtype=x.dtype, device=x.device) + ) + ( + infer_output_right_context_utterance, + infer_output_memory, + state, + ) = encoder_layer._apply_attention_infer( + chunk, + chunk_length, + chunk_right_context, + chunk_memory, + state, + ) + infer_output_utterance = infer_output_right_context_utterance[ + chunk_right_context.size(0) : # noqa + ] + print( + infer_output_utterance + - forward_output_utterance[start_idx:end_idx] + ) + + +def test_emformer_layer_forward_infer_consistency(): + from emformer import EmformerEncoder + + chunk_length = 4 + num_chunks = 3 + U = chunk_length * num_chunks + L, R = 1, 2 + D = 256 + num_encoder_layers = 1 + memory_sizes = [0, 3] + + for M in memory_sizes: + encoder = EmformerEncoder( + chunk_length=chunk_length, + d_model=D, + dim_feedforward=1024, + num_encoder_layers=num_encoder_layers, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + dropout=0.0, + ) + encoder_layer = encoder.emformer_layers[0] + + x = torch.randn(U + R, 1, D) + lengths = torch.tensor([U]) + right_context = encoder._gen_right_context(x) + utterance = x[: x.size(0) - R] + attention_mask = encoder._gen_attention_mask(utterance) + memory = ( + encoder.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1 + ] + if encoder.use_memory + else torch.empty(0).to(dtype=x.dtype, device=x.device) + ) + ( + forward_output_utterance, + forward_output_right_context, + forward_output_memory, + ) = encoder_layer( + utterance, + lengths, + right_context, + memory, + attention_mask, + ) + + state = None + for chunk_idx in range(num_chunks): + start_idx = chunk_idx * chunk_length + end_idx = start_idx + chunk_length + chunk = x[start_idx:end_idx] + chunk_right_context = x[end_idx : end_idx + R] # noqa + chunk_length = torch.tensor([chunk_length]) + chunk_memory = ( + encoder.init_memory_op(chunk.permute(1, 2, 0)).permute(2, 0, 1) + if encoder.use_memory + else torch.empty(0).to(dtype=x.dtype, device=x.device) + ) + ( + infer_output_utterance, + infer_right_context, + infer_output_memory, + state, + ) = encoder_layer.infer( + chunk, + chunk_length, + chunk_right_context, + chunk_memory, + state, + ) + print( + infer_output_utterance + - forward_output_utterance[start_idx:end_idx] + ) + + +def test_emformer_encoder_forward_infer_consistency(): + from emformer import EmformerEncoder + + chunk_length = 4 + num_chunks = 3 + U = chunk_length * num_chunks + L, R = 1, 2 + D = 256 + num_encoder_layers = 3 + memory_sizes = [0, 3] + + for M in memory_sizes: + encoder = EmformerEncoder( + chunk_length=chunk_length, + d_model=D, + dim_feedforward=1024, + num_encoder_layers=num_encoder_layers, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + dropout=0.0, + ) + + x = torch.randn(U + R, 1, D) + lengths = torch.tensor([U + R]) + + forward_output, forward_output_lengths = encoder(x, lengths) + + states = None + for chunk_idx in range(num_chunks): + start_idx = chunk_idx * chunk_length + end_idx = start_idx + chunk_length + chunk = x[start_idx : end_idx + R] # noqa + chunk_right_context = x[end_idx : end_idx + R] # noqa + chunk_length = torch.tensor([chunk_length]) + infer_output, infer_output_lengths, states = encoder.infer( + chunk, + chunk_length, + states, + ) + print(infer_output - forward_output[start_idx:end_idx]) + + if __name__ == "__main__": - test_emformer_attention_forward() - test_emformer_attention_infer() - test_emformer_layer_forward() - test_emformer_layer_infer() - test_emformer_encoder_forward() - test_emformer_encoder_infer() - test_emformer_forward() - test_emformer_infer() + # test_emformer_attention_forward() + # test_emformer_attention_infer() + # test_emformer_layer_forward() + # test_emformer_layer_infer() + # test_emformer_encoder_forward() + # test_emformer_encoder_infer() + # test_emformer_forward() + # test_emformer_infer() + # test_emformer_attention_forward_infer_consistency() + # test_emformer_layer_forward_infer_consistency() + test_emformer_encoder_forward_infer_consistency() From d88e7865138a8542d47e8c3a56c79ac66fc313d9 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Thu, 14 Apr 2022 09:54:07 +0800 Subject: [PATCH 219/234] Changes for pretrained.py (tedlium3 pruned RNN-T) (#311) --- .../beam_search.py | 25 +- .../beam_search.py | 747 +----------------- .../pruned_transducer_stateless/pretrained.py | 138 +++- 3 files changed, 128 insertions(+), 782 deletions(-) mode change 100644 => 120000 egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 815e1c02a..ef1f399c6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -1,4 +1,5 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -86,7 +87,12 @@ def fast_beam_search( # (shape.NumElements(), 1, encoder_out_dim) # fmt: off current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1) + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).long() + # in some old versions of pytorch, the type of index requires + # to be LongTensor. In the newest version of pytorch, the type + # of index can be IntTensor or LongTensor. For supporting the + # old and new versions of pytorch, we set the type of index + # to LongTensor. ) # fmt: on logits = model.joiner( @@ -124,6 +130,7 @@ def greedy_search( assert encoder_out.size(0) == 1, encoder_out.size(0) blank_id = model.decoder.blank_id + unk_id = model.decoder.unk_id context_size = model.decoder.context_size device = model.device @@ -160,7 +167,7 @@ def greedy_search( # logits is (1, 1, 1, vocab_size) y = logits.argmax().item() - if y != blank_id: + if y != blank_id and y != unk_id: hyp.append(y) decoder_input = torch.tensor( [hyp[-context_size:]], device=device @@ -200,6 +207,7 @@ def greedy_search_batch( T = encoder_out.size(1) blank_id = model.decoder.blank_id + unk_id = model.decoder.unk_id context_size = model.decoder.context_size hyps = [[blank_id] * context_size for _ in range(batch_size)] @@ -223,7 +231,7 @@ def greedy_search_batch( y = logits.argmax(dim=1).tolist() emitted = False for i, v in enumerate(y): - if v != blank_id: + if v != blank_id and v != unk_id: hyps[i].append(v) emitted = True if emitted: @@ -415,6 +423,7 @@ def modified_beam_search( T = encoder_out.size(1) blank_id = model.decoder.blank_id + unk_id = model.decoder.unk_id context_size = model.decoder.context_size device = model.device B = [HypothesisList() for _ in range(batch_size)] @@ -491,7 +500,7 @@ def modified_beam_search( new_ys = hyp.ys[:] new_token = topk_token_indexes[k] - if new_token != blank_id: + if new_token != blank_id and new_token != unk_id: new_ys.append(new_token) new_log_prob = topk_log_probs[k] @@ -532,6 +541,7 @@ def _deprecated_modified_beam_search( # support only batch_size == 1 for now assert encoder_out.size(0) == 1, encoder_out.size(0) blank_id = model.decoder.blank_id + unk_id = model.decoder.unk_id context_size = model.decoder.context_size device = model.device @@ -597,7 +607,7 @@ def _deprecated_modified_beam_search( hyp = A[topk_hyp_indexes[i]] new_ys = hyp.ys[:] new_token = topk_token_indexes[i] - if new_token != blank_id: + if new_token != blank_id and new_token != unk_id: new_ys.append(new_token) new_log_prob = topk_log_probs[i] new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) @@ -634,6 +644,7 @@ def beam_search( # support only batch_size == 1 for now assert encoder_out.size(0) == 1, encoder_out.size(0) blank_id = model.decoder.blank_id + unk_id = model.decoder.unk_id context_size = model.decoder.context_size device = model.device @@ -714,7 +725,7 @@ def beam_search( # Second, process other non-blank labels values, indices = log_prob.topk(beam + 1) for i, v in zip(indices.tolist(), values.tolist()): - if i == blank_id: + if i == blank_id or i == unk_id: continue new_ys = y_star.ys + [i] new_log_prob = y_star.log_prob + v diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py b/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py deleted file mode 100644 index 3a08b100d..000000000 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py +++ /dev/null @@ -1,746 +0,0 @@ -# Copyright 2020 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from typing import Dict, List, Optional - -import k2 -import torch -from model import Transducer - -from icefall.decode import one_best_decoding -from icefall.utils import get_texts - - -def fast_beam_search( - model: Transducer, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, -) -> List[List[int]]: - """It limits the maximum number of symbols per frame to 1. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - Returns: - Return the decoded result. - """ - assert encoder_out.ndim == 3 - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - unk_id = model.decoder.unk_id - - B, T, C = encoder_out.shape - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(k2.RnntDecodingStream(decoding_graph)) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - # current_encoder_out is of shape - # (shape.NumElements(), 1, encoder_out_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), decoder_out.unsqueeze(1) - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = logits.log_softmax(dim=-1) - decoding_streams.advance(log_probs) - decoding_streams.terminate_and_flush_to_streams() - lattice = decoding_streams.format_output(encoder_out_lens.tolist()) - best_path = one_best_decoding(lattice) - hyps = get_texts(best_path) - new_hyps = [] - for hyp in hyps: - hyp = [idx for idx in hyp if idx != unk_id] - new_hyps.append(hyp) - return new_hyps - - -def greedy_search( - model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int -) -> List[int]: - """Greedy search for a single utterance. - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - max_sym_per_frame: - Maximum number of symbols per frame. If it is set to 0, the WER - would be 100%. - Returns: - Return the decoded result. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - - blank_id = model.decoder.blank_id - unk_id = model.decoder.unk_id - context_size = model.decoder.context_size - - device = model.device - - decoder_input = torch.tensor( - [blank_id] * context_size, device=device, dtype=torch.int64 - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - - T = encoder_out.size(1) - t = 0 - hyp = [blank_id] * context_size - - # Maximum symbols per utterance. - max_sym_per_utt = 1000 - - # symbols per frame - sym_per_frame = 0 - - # symbols per utterance decoded so far - sym_per_utt = 0 - - while t < T and sym_per_utt < max_sym_per_utt: - if sym_per_frame >= max_sym_per_frame: - sym_per_frame = 0 - t += 1 - continue - - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) - # logits is (1, 1, 1, vocab_size) - - y = logits.argmax().item() - if y != blank_id and y != unk_id: - hyp.append(y) - decoder_input = torch.tensor( - [hyp[-context_size:]], device=device - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - - sym_per_utt += 1 - sym_per_frame += 1 - else: - sym_per_frame = 0 - t += 1 - hyp = hyp[context_size:] # remove blanks - - return hyp - - -def greedy_search_batch( - model: Transducer, encoder_out: torch.Tensor -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - Returns: - Return a list-of-list integers containing the decoded results. - len(ans) equals to encoder_out.size(0). - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - device = model.device - - batch_size = encoder_out.size(0) - T = encoder_out.size(1) - - blank_id = model.decoder.blank_id - unk_id = model.decoder.unk_id - context_size = model.decoder.context_size - - hyps = [[blank_id] * context_size for _ in range(batch_size)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (batch_size, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - # decoder_out: (batch_size, 1, decoder_out_dim) - for t in range(T): - current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) - # logits'shape (batch_size, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id and v != unk_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps] - decoder_input = torch.tensor(decoder_input, device=device) - decoder_out = model.decoder(decoder_input, need_pad=False) - - ans = [h[context_size:] for h in hyps] - return ans - - -@dataclass -class Hypothesis: - # The predicted tokens so far. - # Newly predicted tokens are appended to `ys`. - ys: List[int] - - # The log prob of ys. - # It contains only one entry. - log_prob: torch.Tensor - - @property - def key(self) -> str: - """Return a string representation of self.ys""" - return "_".join(map(str, self.ys)) - - -class HypothesisList(object): - def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: - """ - Args: - data: - A dict of Hypotheses. Its key is its `value.key`. - """ - if data is None: - self._data = {} - else: - self._data = data - - @property - def data(self) -> Dict[str, Hypothesis]: - return self._data - - def add(self, hyp: Hypothesis) -> None: - """Add a Hypothesis to `self`. - - If `hyp` already exists in `self`, its probability is updated using - `log-sum-exp` with the existed one. - - Args: - hyp: - The hypothesis to be added. - """ - key = hyp.key - if key in self: - old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) - else: - self._data[key] = hyp - - def get_most_probable(self, length_norm: bool = False) -> Hypothesis: - """Get the most probable hypothesis, i.e., the one with - the largest `log_prob`. - - Args: - length_norm: - If True, the `log_prob` of a hypothesis is normalized by the - number of tokens in it. - Returns: - Return the hypothesis that has the largest `log_prob`. - """ - if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) - else: - return max(self._data.values(), key=lambda hyp: hyp.log_prob) - - def remove(self, hyp: Hypothesis) -> None: - """Remove a given hypothesis. - - Caution: - `self` is modified **in-place**. - - Args: - hyp: - The hypothesis to be removed from `self`. - Note: It must be contained in `self`. Otherwise, - an exception is raised. - """ - key = hyp.key - assert key in self, f"{key} does not exist" - del self._data[key] - - def filter(self, threshold: torch.Tensor) -> "HypothesisList": - """Remove all Hypotheses whose log_prob is less than threshold. - - Caution: - `self` is not modified. Instead, a new HypothesisList is returned. - - Returns: - Return a new HypothesisList containing all hypotheses from `self` - with `log_prob` being greater than the given `threshold`. - """ - ans = HypothesisList() - for _, hyp in self._data.items(): - if hyp.log_prob > threshold: - ans.add(hyp) # shallow copy - return ans - - def topk(self, k: int) -> "HypothesisList": - """Return the top-k hypothesis.""" - hyps = list(self._data.items()) - - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] - - ans = HypothesisList(dict(hyps)) - return ans - - def __contains__(self, key: str): - return key in self._data - - def __iter__(self): - return iter(self._data.values()) - - def __len__(self) -> int: - return len(self._data) - - def __str__(self) -> str: - s = [] - for key in self: - s.append(key) - return ", ".join(s) - - -def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: - """Return a ragged shape with axes [utt][num_hyps]. - - Args: - hyps: - len(hyps) == batch_size. It contains the current hypothesis for - each utterance in the batch. - Returns: - Return a ragged shape with 2 axes [utt][num_hyps]. Note that - the shape is on CPU. - """ - num_hyps = [len(h) for h in hyps] - - # torch.cumsum() is inclusive sum, so we put a 0 at the beginning - # to get exclusive sum later. - num_hyps.insert(0, 0) - - num_hyps = torch.tensor(num_hyps) - row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) - ans = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=row_splits[-1].item() - ) - return ans - - -def modified_beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 4, -) -> List[List[int]]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - beam: - Number of active paths during the beam search. - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - - batch_size = encoder_out.size(0) - T = encoder_out.size(1) - - blank_id = model.decoder.blank_id - unk_id = model.decoder.unk_id - context_size = model.decoder.context_size - device = model.device - B = [HypothesisList() for _ in range(batch_size)] - for i in range(batch_size): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - - for t in range(T): - current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - - hyps_shape = _get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - # decoder_output is of shape (num_hyps, 1, 1, decoder_output_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - topk_hyp_indexes = torch.div( - topk_indexes, vocab_size, rounding_mode="trunc" - ) - topk_hyp_indexes = topk_hyp_indexes.tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token != blank_id and new_token != unk_id: - new_ys.append(new_token) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B[i].add(new_hyp) - - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - ans = [h.ys[context_size:] for h in best_hyps] - - return ans - - -def _deprecated_modified_beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 4, -) -> List[int]: - """It limits the maximum number of symbols per frame to 1. - - It decodes only one utterance at a time. We keep it only for reference. - The function :func:`modified_beam_search` should be preferred as it - supports batch decoding. - - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - Returns: - Return the decoded result. - """ - - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = model.decoder.unk_id - context_size = model.decoder.context_size - - device = model.device - - T = encoder_out.size(1) - - B = HypothesisList() - B.add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - - for t in range(T): - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) - # fmt: on - A = list(B) - B = HypothesisList() - - ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) - # ys_log_probs is of shape (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyp in A], - device=device, - dtype=torch.int64, - ) - # decoder_input is of shape (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - # decoder_output is of shape (num_hyps, 1, 1, decoder_output_dim) - - current_encoder_out = current_encoder_out.expand( - decoder_out.size(0), 1, 1, -1 - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - ) - # logits is of shape (num_hyps, 1, 1, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - # now logits is of shape (num_hyps, vocab_size) - log_probs = logits.log_softmax(dim=-1) - - log_probs.add_(ys_log_probs) - - log_probs = log_probs.reshape(-1) - topk_log_probs, topk_indexes = log_probs.topk(beam) - - # topk_hyp_indexes are indexes into `A` - topk_hyp_indexes = topk_indexes // logits.size(-1) - topk_token_indexes = topk_indexes % logits.size(-1) - - topk_hyp_indexes = topk_hyp_indexes.tolist() - topk_token_indexes = topk_token_indexes.tolist() - - for i in range(len(topk_hyp_indexes)): - hyp = A[topk_hyp_indexes[i]] - new_ys = hyp.ys[:] - new_token = topk_token_indexes[i] - if new_token != blank_id and new_token != unk_id: - new_ys.append(new_token) - new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B.add(new_hyp) - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - return ys - - -def beam_search( - model: Transducer, - encoder_out: torch.Tensor, - beam: int = 4, -) -> List[int]: - """ - It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf - - espnet/nets/beam_search_transducer.py#L247 is used as a reference. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - Returns: - Return the decoded result. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = model.decoder.unk_id - context_size = model.decoder.context_size - - device = model.device - - decoder_input = torch.tensor( - [blank_id] * context_size, - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - - T = encoder_out.size(1) - t = 0 - - B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) - - max_sym_per_utt = 20000 - - sym_per_utt = 0 - - decoder_cache: Dict[str, torch.Tensor] = {} - - while t < T and sym_per_utt < max_sym_per_utt: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - A = B - B = HypothesisList() - - joint_cache: Dict[str, torch.Tensor] = {} - - # TODO(fangjun): Implement prefix search to update the `log_prob` - # of hypotheses in A - - while True: - y_star = A.get_most_probable() - A.remove(y_star) - - cached_key = y_star.key - - if cached_key not in decoder_cache: - decoder_input = torch.tensor( - [y_star.ys[-context_size:]], - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_cache[cached_key] = decoder_out - else: - decoder_out = decoder_cache[cached_key] - - cached_key += f"-t-{t}" - if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1) - ) - - # TODO(fangjun): Scale the blank posterior - - log_prob = logits.log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) - joint_cache[cached_key] = log_prob - else: - log_prob = joint_cache[cached_key] - - # First, process the blank symbol - skip_log_prob = log_prob[blank_id] - new_y_star_log_prob = y_star.log_prob + skip_log_prob - - # ys[:] returns a copy of ys - B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) - - # Second, process other non-blank labels - values, indices = log_prob.topk(beam + 1) - for i, v in zip(indices.tolist(), values.tolist()): - if i == blank_id or i == unk_id: - continue - new_ys = y_star.ys + [i] - new_log_prob = y_star.log_prob + v - A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) - - # Check whether B contains more than "beam" elements more probable - # than the most probable in A - A_most_probable = A.get_most_probable() - - kept_B = B.filter(A_most_probable.log_prob) - - if len(kept_B) >= beam: - B = kept_B.topk(beam) - break - - t += 1 - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - return ys diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py b/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py new file mode 120000 index 000000000..7f9f6263f --- /dev/null +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless/beam_search.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py index 2c795ede0..08e4962e2 100644 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py @@ -36,7 +36,6 @@ Usage: /path/to/foo.wav \ /path/to/bar.wav - (3) modified beam search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ @@ -46,6 +45,17 @@ Usage: /path/to/foo.wav \ /path/to/bar.wav +(4) fast beam search +./pruned_transducer_stateless/pretrained.py \ + --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 \ + /path/to/foo.wav \ + /path/to/bar.wav + You can also use `./pruned_transducer_stateless/exp/epoch-xx.pt`. Note: ./pruned_transducer_stateless/exp/pretrained.pt is generated by @@ -58,12 +68,19 @@ import logging import math from typing import List +import k2 import kaldifeat import sentencepiece as spm import torch import torch.nn as nn import torchaudio -from beam_search import beam_search, greedy_search, modified_beam_search +from beam_search import ( + beam_search, + fast_beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) from conformer import Conformer from decoder import Decoder from joiner import Joiner @@ -97,12 +114,14 @@ def get_parser(): ) parser.add_argument( - "--method", + "--decoding-method", type=str, default="greedy_search", help="""Possible values are: - greedy_search - beam_search + - modified_beam_search + - fast_beam_search """, ) @@ -123,6 +142,32 @@ def get_parser(): help="Used only when --method is beam_search and modified_beam_search ", ) + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + parser.add_argument( "--context-size", type=int, @@ -134,7 +179,7 @@ def get_parser(): parser.add_argument( "--max-sym-per-frame", type=int, - default=3, + default=1, help="""Maximum number of symbols per frame. Used only when --method is greedy_search. """, @@ -268,6 +313,11 @@ def main(): model.eval() model.device = device + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + logging.info("Constructing Fbank computer") opts = kaldifeat.FbankOptions() opts.device = device @@ -299,34 +349,64 @@ def main(): x=features, x_lens=feature_lengths ) - num_waves = encoder_out.size(0) hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" + msg = f"Using {params.decoding_method}" logging.info(msg) - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - elif params.method == "modified_beam_search": - hyp = modified_beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) s = "\n" for filename, hyp in zip(params.sound_files, hyps): From 5fe58de43c970a940a2843406007d160934bf504 Mon Sep 17 00:00:00 2001 From: "Wang, Guanbo" Date: Thu, 14 Apr 2022 04:07:22 -0400 Subject: [PATCH 220/234] GigaSpeech recipe (#120) * initial commit * support download, data prep, and fbank * on-the-fly feature extraction by default * support BPE based lang * support HLG for BPE * small fix * small fix * chunked feature extraction by default * Compute features for GigaSpeech by splitting the manifest. * Fixes after review. * Split manifests into 2000 pieces. * set audio duration mismatch tolerance to 0.01 * small fix * add conformer training recipe * Add conformer.py without pre-commit checking * lazy loading and use SingleCutSampler * DynamicBucketingSampler * use KaldifeatFbank to compute fbank for musan * use pretrained language model and lexicon * use 3gram to decode, 4gram to rescore * Add decode.py * Update .flake8 * Delete compute_fbank_gigaspeech.py * Use BucketingSampler for valid and test dataloader * Update params in train.py * Use bpe_500 * update params in decode.py * Decrease num_paths while CUDA OOM * Added README * Update RESULTS * black * Decrease num_paths while CUDA OOM * Decode with post-processing * Update results * Remove lazy_load option * Use default `storage_type` * Keep the original tolerance * Use split-lazy * black * Update pretrained model Co-authored-by: Fangjun Kuang --- .flake8 | 1 + .gitignore | 2 + egs/gigaspeech/ASR/.gitignore | 1 + egs/gigaspeech/ASR/README.md | 20 + egs/gigaspeech/ASR/RESULTS.md | 79 ++ egs/gigaspeech/ASR/conformer_ctc/__init__.py | 0 .../ASR/conformer_ctc/asr_datamodule.py | 373 +++++++ egs/gigaspeech/ASR/conformer_ctc/conformer.py | 930 +++++++++++++++++ egs/gigaspeech/ASR/conformer_ctc/decode.py | 715 +++++++++++++ .../ASR/conformer_ctc/gigaspeech_scoring.py | 115 +++ .../ASR/conformer_ctc/label_smoothing.py | 98 ++ .../ASR/conformer_ctc/subsampling.py | 161 +++ egs/gigaspeech/ASR/conformer_ctc/train.py | 737 ++++++++++++++ .../ASR/conformer_ctc/transformer.py | 953 ++++++++++++++++++ egs/gigaspeech/ASR/local/__init__.py | 0 egs/gigaspeech/ASR/local/compile_hlg.py | 1 + .../compute_fbank_gigaspeech_dev_test.py | 92 ++ .../local/compute_fbank_gigaspeech_splits.py | 165 +++ .../ASR/local/compute_fbank_musan.py | 103 ++ .../convert_transcript_words_to_tokens.py | 1 + egs/gigaspeech/ASR/local/prepare_lang.py | 1 + egs/gigaspeech/ASR/local/prepare_lang_bpe.py | 1 + .../ASR/local/preprocess_gigaspeech.py | 113 +++ egs/gigaspeech/ASR/local/train_bpe_model.py | 1 + egs/gigaspeech/ASR/prepare.sh | 325 ++++++ egs/gigaspeech/ASR/shared | 1 + icefall/decode.py | 76 +- 27 files changed, 5049 insertions(+), 16 deletions(-) create mode 100644 egs/gigaspeech/ASR/.gitignore create mode 100644 egs/gigaspeech/ASR/README.md create mode 100644 egs/gigaspeech/ASR/RESULTS.md create mode 100644 egs/gigaspeech/ASR/conformer_ctc/__init__.py create mode 100644 egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py create mode 100644 egs/gigaspeech/ASR/conformer_ctc/conformer.py create mode 100755 egs/gigaspeech/ASR/conformer_ctc/decode.py create mode 100755 egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py create mode 100644 egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py create mode 100644 egs/gigaspeech/ASR/conformer_ctc/subsampling.py create mode 100755 egs/gigaspeech/ASR/conformer_ctc/train.py create mode 100644 egs/gigaspeech/ASR/conformer_ctc/transformer.py create mode 100644 egs/gigaspeech/ASR/local/__init__.py create mode 120000 egs/gigaspeech/ASR/local/compile_hlg.py create mode 100755 egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py create mode 100755 egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py create mode 100755 egs/gigaspeech/ASR/local/compute_fbank_musan.py create mode 120000 egs/gigaspeech/ASR/local/convert_transcript_words_to_tokens.py create mode 120000 egs/gigaspeech/ASR/local/prepare_lang.py create mode 120000 egs/gigaspeech/ASR/local/prepare_lang_bpe.py create mode 100755 egs/gigaspeech/ASR/local/preprocess_gigaspeech.py create mode 120000 egs/gigaspeech/ASR/local/train_bpe_model.py create mode 100755 egs/gigaspeech/ASR/prepare.sh create mode 120000 egs/gigaspeech/ASR/shared diff --git a/.flake8 b/.flake8 index 5b3c444b8..cd55ded73 100644 --- a/.flake8 +++ b/.flake8 @@ -7,6 +7,7 @@ per-file-ignores = egs/librispeech/ASR/*/conformer.py: E501, egs/aishell/ASR/*/conformer.py: E501, egs/tedlium3/ASR/*/conformer.py: E501, + egs/gigaspeech/ASR/*/conformer.py: E501, egs/librispeech/ASR/pruned_transducer_stateless2/*.py: E501, # invalid escape sequence (cause by tex formular), W605 diff --git a/.gitignore b/.gitignore index 870d3cea3..1dbf8f395 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ exp exp*/ *.pt download +dask-worker-space +log *.bak *-bak *bak.py diff --git a/egs/gigaspeech/ASR/.gitignore b/egs/gigaspeech/ASR/.gitignore new file mode 100644 index 000000000..5592679cc --- /dev/null +++ b/egs/gigaspeech/ASR/.gitignore @@ -0,0 +1 @@ +log-* diff --git a/egs/gigaspeech/ASR/README.md b/egs/gigaspeech/ASR/README.md new file mode 100644 index 000000000..7796ef2a0 --- /dev/null +++ b/egs/gigaspeech/ASR/README.md @@ -0,0 +1,20 @@ +# GigaSpeech +GigaSpeech, an evolving, multi-domain English +speech recognition corpus with 10,000 hours of high quality labeled +audio, collected from audiobooks, podcasts +and YouTube, covering both read and spontaneous speaking styles, +and a variety of topics, such as arts, science, sports, etc. More details can be found: https://github.com/SpeechColab/GigaSpeech + +## Download + +Apply for the download credentials and download the dataset by following https://github.com/SpeechColab/GigaSpeech#download. Then create a symlink +```bash +ln -sfv /path/to/GigaSpeech download/GigaSpeech +``` + +## Performance Record +| | Dev | Test | +|-----|-------|-------| +| WER | 10.47 | 10.58 | + +See [RESULTS](/egs/gigaspeech/ASR/RESULTS.md) for details. diff --git a/egs/gigaspeech/ASR/RESULTS.md b/egs/gigaspeech/ASR/RESULTS.md new file mode 100644 index 000000000..b29e893da --- /dev/null +++ b/egs/gigaspeech/ASR/RESULTS.md @@ -0,0 +1,79 @@ +## Results + +### GigaSpeech BPE training results (Conformer-CTC) + +#### 2022-04-06 + +The best WER, as of 2022-04-06, for the gigaspeech is below + +Results using HLG decoding + n-gram LM rescoring + attention decoder rescoring: + +| | Dev | Test | +|-----|-------|-------| +| WER | 10.47 | 10.58 | + +Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.5 | 1.3 | + + +To reproduce the above result, use the following commands for training: + +``` +cd egs/gigaspeech/ASR +./prepare.sh +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" +./conformer_ctc/train.py \ + --max-duration 120 \ + --num-workers 1 \ + --world-size 8 \ + --exp-dir conformer_ctc/exp_500 \ + --lang-dir data/lang_bpe_500 +``` + +and the following command for decoding: + +``` +./conformer_ctc/decode.py \ + --epoch 18 \ + --avg 6 \ + --method attention-decoder \ + --num-paths 1000 \ + --exp-dir conformer_ctc/exp_500 \ + --lang-dir data/lang_bpe_500 \ + --max-duration 20 \ + --num-workers 1 +``` + +Results using HLG decoding + whole lattice rescoring: + +| | Dev | Test | +|-----|-------|-------| +| WER | 10.51 | 10.62 | + +Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: +| lm_scale | +|----------| +| 0.2 | + +To reproduce the above result, use the training commands above, and the following command for decoding: + +``` +./conformer_ctc/decode.py \ + --epoch 18 \ + --avg 6 \ + --method whole-lattice-rescoring \ + --num-paths 1000 \ + --exp-dir conformer_ctc/exp_500 \ + --lang-dir data/lang_bpe_500 \ + --max-duration 20 \ + --num-workers 1 +``` +Note: the `whole-lattice-rescoring` method is about twice as fast as the `attention-decoder` method, with slightly worse WER. + +Pretrained model is available at + + +The tensorboard log for training is available at + diff --git a/egs/gigaspeech/ASR/conformer_ctc/__init__.py b/egs/gigaspeech/ASR/conformer_ctc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py new file mode 100644 index 000000000..ab958fa68 --- /dev/null +++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py @@ -0,0 +1,373 @@ +# Copyright 2021 Piotr Żelasko +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path + +from lhotse import CutSet, Fbank, FbankConfig, load_manifest +from lhotse.dataset import ( + BucketingSampler, + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class GigaSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it " + "with training dataset. ", + ) + + # GigaSpeech specific arguments + group.add_argument( + "--subset", + type=str, + default="XL", + help="Select the GigaSpeech subset (XS|S|M|L|XL)", + ) + group.add_argument( + "--small-dev", + type=str2bool, + default=False, + help="Should we use only 1000 utterances for dev " + "(speeds up training)", + ) + + def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: + logging.info("About to get Musan cuts") + cuts_musan = load_manifest( + self.args.manifest_dir / "cuts_musan.json.gz" + ) + + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + transforms.append( + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=2, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=True, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = BucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = BucketingSampler( + cuts, max_duration=self.args.max_duration, shuffle=False + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info(f"About to get train_{self.args.subset} cuts") + path = self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz" + cuts_train = CutSet.from_jsonl_lazy(path) + return cuts_train + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + cuts_valid = load_manifest(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + if self.args.small_dev: + return cuts_valid.subset(first=1000) + else: + return cuts_valid + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest(self.args.manifest_dir / "cuts_TEST.jsonl.gz") diff --git a/egs/gigaspeech/ASR/conformer_ctc/conformer.py b/egs/gigaspeech/ASR/conformer_ctc/conformer.py new file mode 100644 index 000000000..871712a46 --- /dev/null +++ b/egs/gigaspeech/ASR/conformer_ctc/conformer.py @@ -0,0 +1,930 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +from typing import Optional, Tuple, Union + +import torch +from torch import Tensor, nn +from transformer import Supervisions, Transformer, encoder_padding_mask + + +class Conformer(Transformer): + """ + Args: + num_features (int): Number of input features + num_classes (int): Number of output classes + subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) + d_model (int): attention dimension + nhead (int): number of head + dim_feedforward (int): feedforward dimention + num_encoder_layers (int): number of encoder layers + num_decoder_layers (int): number of decoder layers + dropout (float): dropout rate + cnn_module_kernel (int): Kernel size of convolution module + normalize_before (bool): whether to use layer_norm before the first block. + vgg_frontend (bool): whether to use vgg frontend. + """ + + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + vgg_frontend: bool = False, + use_feat_batchnorm: Union[float, bool] = 0.1, + ) -> None: + super(Conformer, self).__init__( + num_features=num_features, + num_classes=num_classes, + subsampling_factor=subsampling_factor, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + dropout=dropout, + normalize_before=normalize_before, + vgg_frontend=vgg_frontend, + use_feat_batchnorm=use_feat_batchnorm, + ) + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + use_conv_batchnorm = True + if isinstance(use_feat_batchnorm, float): + use_conv_batchnorm = False + encoder_layer = ConformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + cnn_module_kernel, + normalize_before, + use_conv_batchnorm, + ) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + self.normalize_before = normalize_before + if self.normalize_before: + self.after_norm = nn.LayerNorm(d_model) + else: + # Note: TorchScript detects that self.after_norm could be used inside forward() + # and throws an error without this change. + self.after_norm = identity + + def run_encoder( + self, x: Tensor, supervisions: Optional[Supervisions] = None + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x: + The model input. Its shape is (N, T, C). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. + + Returns: + Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). + Tensor: Mask tensor of dimension (batch_size, input_length) + """ + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + mask = encoder_padding_mask(x.size(0), supervisions) + if mask is not None: + mask = mask.to(x.device) + x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) + + if self.normalize_before: + x = self.after_norm(x) + + return x, mask + + +class ConformerEncoderLayer(nn.Module): + """ + ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + See: "Conformer: Convolution-augmented Transformer for Speech Recognition" + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + normalize_before: whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + use_conv_batchnorm: bool = False, + ) -> None: + super(ConformerEncoderLayer, self).__init__() + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) + + self.feed_forward = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.feed_forward_macaron = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.conv_module = ConvolutionModule( + d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm + ) + + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module + self.norm_ff = nn.LayerNorm(d_model) # for the FNN module + self.norm_mha = nn.LayerNorm(d_model) # for the MHA module + + self.ff_scale = 0.5 + + self.norm_conv = nn.LayerNorm(d_model) # for the CNN module + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block + + self.dropout = nn.Dropout(dropout) + + self.normalize_before = normalize_before + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + + # macaron style feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff_macaron(src) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) + if not self.normalize_before: + src = self.norm_ff_macaron(src) + + # multi-headed self-attention module + residual = src + if self.normalize_before: + src = self.norm_mha(src) + src_att = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout(src_att) + if not self.normalize_before: + src = self.norm_mha(src) + + # convolution module + residual = src + if self.normalize_before: + src = self.norm_conv(src) + src = residual + self.dropout(self.conv_module(src)) + if not self.normalize_before: + src = self.norm_conv(src) + + # feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff(src) + src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) + if not self.normalize_before: + src = self.norm_ff(src) + + if self.normalize_before: + src = self.norm_final(src) + + return src + + +class ConformerEncoder(nn.TransformerEncoder): + r"""ConformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ConformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = conformer_encoder(src, pos_emb) + """ + + def __init__( + self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None + ) -> None: + super(ConformerEncoder, self).__init__( + encoder_layer=encoder_layer, num_layers=num_layers, norm=norm + ) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + """ + output = src + + for mod in self.layers: + output = mod( + output, + pos_emb, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: Tensor) -> None: + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x.size(1) + + 1 : self.pe.size(1) // 2 # noqa E203 + + x.size(1), + ] + return self.dropout(x), self.dropout(pos_emb) + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + # linear transformation for positional encoding. + self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + nn.init.xavier_uniform_(self.in_proj.weight) + nn.init.constant_(self.in_proj.bias, 0.0) + nn.init.constant_(self.out_proj.bias, 0.0) + + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.weight, + self.in_proj.bias, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + ) + + def rel_shift(self, x: Tensor) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + assert n == 2 * time1 - 1 + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time1), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) + k = k.contiguous().view(-1, bsz, num_heads, head_dim) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + src_len = k.size(0) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == src_len, "{} == {}".format( + key_padding_mask.size(1), src_len + ) + + q = q.transpose(0, 1) # (batch, time1, head, d_k) + + pos_emb_bsz = pos_emb.size(0) + assert pos_emb_bsz in (1, bsz) # actually it is 1 + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + q_with_bias_u = (q + self.pos_bias_u).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + q_with_bias_v = (q + self.pos_bias_v).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) + + # compute matrix b and matrix d + matrix_bd = torch.matmul( + q_with_bias_v, p.transpose(-2, -1) + ) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift(matrix_bd) + + attn_output_weights = ( + matrix_ac + matrix_bd + ) * scaling # (batch, head, time1, time2) + + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) + + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + bias: bool = True, + use_batchnorm: bool = False, + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + self.use_batchnorm = use_batchnorm + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + if self.use_batchnorm: + self.norm = nn.BatchNorm1d(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = Swish() + + def forward(self, x: Tensor) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + if self.use_batchnorm: + x = self.norm(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + +def identity(x): + return x diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py new file mode 100755 index 000000000..a810bef06 --- /dev/null +++ b/egs/gigaspeech/ASR/conformer_ctc/decode.py @@ -0,0 +1,715 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) +# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import GigaSpeechAsrDataModule +from conformer import Conformer +from gigaspeech_scoring import asr_text_post_processing + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=0, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--method", + type=str, + default="attention-decoder", + help="""Decoding method. + Supported values are: + - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (1) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (2) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (3) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (4) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + - (5) attention-decoder. Extract n paths from the LM rescored + lattice, the path with the highest score is the decoding result. + - (6) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=1000, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="The lang dir", + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + # parameters for conformer + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "num_decoder_layers": 6, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + return params + + +def post_processing( + results: List[Tuple[List[str], List[str]]], +) -> List[Tuple[List[str], List[str]]]: + new_results = [] + for ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)) + new_hyp = asr_text_post_processing(" ".join(hyp)) + new_results.append((new_ref, new_hyp)) + return new_results + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + + - params.method is "1best", it uses 1best decoding without LM rescoring. + - params.method is "nbest", it uses nbest decoding without LM rescoring. + - params.method is "nbest-rescoring", it uses nbest LM rescoring. + - params.method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + + nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + supervisions["num_frames"] // params.subsampling_factor, + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. + best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.method in ["1best", "nbest"]: + if params.method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + elif params.method == "attention-decoder": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + # TODO: pass `lattice` instead of `rescored_lattice` to + # `rescore_with_attention_decoder` + + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + nbest_scale=params.nbest_scale, + ) + else: + assert False, f"Unsupported decoding method: {params.method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + if hyps_dict is not None: + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + else: + assert ( + len(results) > 0 + ), "It should not decode to empty in the first batch!" + this_batch = [] + hyp_words = [] + for ref_text in texts: + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + for lm_scale in results.keys(): + results[lm_scale].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + if params.method == "attention-decoder": + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + results = post_processing(results) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + logging.info("Decoding started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + if params.method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.method in ["whole-lattice-rescoring", "attention-decoder"]: + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + gigaspeech = GigaSpeechAsrDataModule(args) + + dev_cuts = gigaspeech.dev_cuts() + test_cuts = gigaspeech.test_cuts() + + dev_dl = gigaspeech.test_dataloaders(dev_cuts) + test_dl = gigaspeech.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py b/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py new file mode 100755 index 000000000..ef53b77f8 --- /dev/null +++ b/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +# Copyright 2021 Jiayu Du +# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os + +conversational_filler = [ + "UH", + "UHH", + "UM", + "EH", + "MM", + "HM", + "AH", + "HUH", + "HA", + "ER", + "OOF", + "HEE", + "ACH", + "EEE", + "EW", +] +unk_tags = ["", ""] +gigaspeech_punctuations = [ + "", + "", + "", + "", +] +gigaspeech_garbage_utterance_tags = ["", "", "", ""] +non_scoring_words = ( + conversational_filler + + unk_tags + + gigaspeech_punctuations + + gigaspeech_garbage_utterance_tags +) + + +def asr_text_post_processing(text: str) -> str: + # 1. convert to uppercase + text = text.upper() + + # 2. remove hyphen + # "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART" + text = text.replace("-", " ") + + # 3. remove non-scoring words from evaluation + remaining_words = [] + for word in text.split(): + if word in non_scoring_words: + continue + remaining_words.append(word) + + return " ".join(remaining_words) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="This script evaluates GigaSpeech ASR result via" + "SCTK's tool sclite" + ) + parser.add_argument( + "ref", + type=str, + help="sclite's standard transcription(trn) reference file", + ) + parser.add_argument( + "hyp", + type=str, + help="sclite's standard transcription(trn) hypothesis file", + ) + parser.add_argument( + "work_dir", + type=str, + help="working dir", + ) + args = parser.parse_args() + + if not os.path.isdir(args.work_dir): + os.mkdir(args.work_dir) + + REF = os.path.join(args.work_dir, "REF") + HYP = os.path.join(args.work_dir, "HYP") + RESULT = os.path.join(args.work_dir, "RESULT") + + for io in [(args.ref, REF), (args.hyp, HYP)]: + with open(io[0], "r", encoding="utf8") as fi: + with open(io[1], "w+", encoding="utf8") as fo: + for line in fi: + line = line.strip() + if line: + cols = line.split() + text = asr_text_post_processing(" ".join(cols[0:-1])) + uttid_field = cols[-1] + print(f"{text} {uttid_field}", file=fo) + + # GigaSpeech's uttid comforms to swb + os.system(f"sclite -r {REF} trn -h {HYP} trn -i swb | tee {RESULT}") diff --git a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py new file mode 100644 index 000000000..cdc85ce9a --- /dev/null +++ b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py @@ -0,0 +1,98 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +class LabelSmoothingLoss(torch.nn.Module): + """ + Implement the LabelSmoothingLoss proposed in the following paper + https://arxiv.org/pdf/1512.00567.pdf + (Rethinking the Inception Architecture for Computer Vision) + + """ + + def __init__( + self, + ignore_index: int = -1, + label_smoothing: float = 0.1, + reduction: str = "sum", + ) -> None: + """ + Args: + ignore_index: + ignored class id + label_smoothing: + smoothing rate (0.0 means the conventional cross entropy loss) + reduction: + It has the same meaning as the reduction in + `torch.nn.CrossEntropyLoss`. It can be one of the following three + values: (1) "none": No reduction will be applied. (2) "mean": the + mean of the output is taken. (3) "sum": the output will be summed. + """ + super().__init__() + assert 0.0 <= label_smoothing < 1.0 + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + self.reduction = reduction + + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute loss between x and target. + + Args: + x: + prediction of dimension + (batch_size, input_length, number_of_classes). + target: + target masked with self.ignore_index of + dimension (batch_size, input_length). + + Returns: + A scalar tensor containing the loss without normalization. + """ + assert x.ndim == 3 + assert target.ndim == 2 + assert x.shape[:2] == target.shape + num_classes = x.size(-1) + x = x.reshape(-1, num_classes) + # Now x is of shape (N*T, C) + + # We don't want to change target in-place below, + # so we make a copy of it here + target = target.clone().reshape(-1) + + ignored = target == self.ignore_index + target[ignored] = 0 + + true_dist = torch.nn.functional.one_hot( + target, num_classes=num_classes + ).to(x) + + true_dist = ( + true_dist * (1 - self.label_smoothing) + + self.label_smoothing / num_classes + ) + # Set the value of ignored indexes to 0 + true_dist[ignored] = 0 + + loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) + if self.reduction == "sum": + return loss.sum() + elif self.reduction == "mean": + return loss.sum() / (~ignored).sum() + else: + return loss.sum(dim=-1) diff --git a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py new file mode 100644 index 000000000..542fb0364 --- /dev/null +++ b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py @@ -0,0 +1,161 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__(self, idim: int, odim: int) -> None: + """ + Args: + idim: + Input dim. The input shape is (N, T, idim). + Caution: It requires: T >=7, idim >=7 + odim: + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) + """ + assert idim >= 7 + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), + nn.ReLU(), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), + nn.ReLU(), + ) + self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + return x + + +class VggSubsampling(nn.Module): + """Trying to follow the setup described in the following paper: + https://arxiv.org/pdf/1910.09799.pdf + + This paper is not 100% explicit so I am guessing to some extent, + and trying to compare with other VGG implementations. + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-1)//2 - 1)//2, which approximates T' = T//4 + """ + + def __init__(self, idim: int, odim: int) -> None: + """Construct a VggSubsampling object. + + This uses 2 VGG blocks with 2 Conv2d layers each, + subsampling its input by a factor of 4 in the time dimensions. + + Args: + idim: + Input dim. The input shape is (N, T, idim). + Caution: It requires: T >=7, idim >=7 + odim: + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) + """ + super().__init__() + + cur_channels = 1 + layers = [] + block_dims = [32, 64] + + # The decision to use padding=1 for the 1st convolution, then padding=0 + # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by + # a back-compatibility concern so that the number of frames at the + # output would be equal to: + # (((T-1)//2)-1)//2. + # We can consider changing this by using padding=1 on the + # 2nd convolution, so the num-frames at the output would be T//4. + for block_dim in block_dims: + layers.append( + torch.nn.Conv2d( + in_channels=cur_channels, + out_channels=block_dim, + kernel_size=3, + padding=1, + stride=1, + ) + ) + layers.append(torch.nn.ReLU()) + layers.append( + torch.nn.Conv2d( + in_channels=block_dim, + out_channels=block_dim, + kernel_size=3, + padding=0, + stride=1, + ) + ) + layers.append( + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) + ) + cur_channels = block_dim + + self.layers = nn.Sequential(*layers) + + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + """ + x = x.unsqueeze(1) + x = self.layers(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + return x diff --git a/egs/gigaspeech/ASR/conformer_ctc/train.py b/egs/gigaspeech/ASR/conformer_ctc/train.py new file mode 100755 index 000000000..2965cde18 --- /dev/null +++ b/egs/gigaspeech/ASR/conformer_ctc/train.py @@ -0,0 +1,737 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import GigaSpeechAsrDataModule +from conformer import Conformer +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from transformer import Noam + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=20, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + conformer_ctc/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--att-rate", + type=float, + default=0.7, + help="""The attention rate. + The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss + """, + ) + + parser.add_argument( + "--lr-factor", + type=float, + default=5.0, + help="The lr_factor for Noam optimizer", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - use_feat_batchnorm: Normalization for the input features, can be a + boolean indicating whether to do batch + normalization, or a float which means just scaling + the input features with this float value. + If given a float value, we will remove batchnorm + layer in `ConvolutionModule` as well. + + - attention_dim: Hidden dim for multi-head attention model. + + - head: Number of heads of multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - beam_size: It is used in k2.ctc_loss + + - reduction: It is used in k2.ctc_loss + + - use_double_scores: It is used in k2.ctc_loss + + - weight_decay: The weight_decay for the optimizer. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 500, + "reset_interval": 2000, + "valid_interval": 30000, + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 8, + "num_decoder_layers": 6, + # parameters for loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + # parameters for Noam + "weight_decay": 1e-6, + "warm_step": 100000, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + graph_compiler: BpeCtcTrainingGraphCompiler, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = graph_compiler.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + with torch.set_grad_enabled(is_training): + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + + token_ids = graph_compiler.texts_to_ids(texts) + + decoding_graph = graph_compiler.compile(token_ids) + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + if params.att_rate != 0.0: + with torch.set_grad_enabled(is_training): + mmodel = model.module if hasattr(model, "module") else model + # Note: We need to generate an unsorted version of token_ids + # `encode_supervisions()` called above sorts text, but + # encoder_memory and memory_mask are not sorted, so we + # use an unsorted version `supervisions["text"]` to regenerate + # the token_ids + # + # See https://github.com/k2-fsa/icefall/issues/97 + # for more details + unsorted_token_ids = graph_compiler.texts_to_ids( + supervisions["text"] + ) + att_loss = mmodel.decoder_forward( + encoder_memory, + memory_mask, + token_ids=unsorted_token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) + loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss + else: + loss = ctc_loss + att_loss = torch.tensor([0]) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + info["frames"] = supervision_segments[:, 2].sum().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.att_rate != 0.0: + info["att_loss"] = att_loss.detach().cpu().item() + + info["loss"] = loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: BpeCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + + if batch_idx % params.log_interval == 0: + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(42) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + model = DDP(model, device_ids=[rank]) + + optimizer = Noam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + weight_decay=params.weight_decay, + ) + + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + + GigaSpeech = GigaSpeechAsrDataModule(args) + + train_cuts = GigaSpeech.train_cuts() + train_dl = GigaSpeech.train_dataloaders(train_cuts) + + valid_cuts = GigaSpeech.dev_cuts() + valid_dl = GigaSpeech.valid_dataloaders(valid_cuts) + + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + for epoch in range(params.start_epoch, params.num_epochs): + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/ASR/conformer_ctc/transformer.py b/egs/gigaspeech/ASR/conformer_ctc/transformer.py new file mode 100644 index 000000000..00ca027a7 --- /dev/null +++ b/egs/gigaspeech/ASR/conformer_ctc/transformer.py @@ -0,0 +1,953 @@ +# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from label_smoothing import LabelSmoothingLoss +from subsampling import Conv2dSubsampling, VggSubsampling +from torch.nn.utils.rnn import pad_sequence + +# Note: TorchScript requires Dict/List/etc. to be fully typed. +Supervisions = Dict[str, torch.Tensor] + + +class Transformer(nn.Module): + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.1, + normalize_before: bool = True, + vgg_frontend: bool = False, + use_feat_batchnorm: Union[float, bool] = 0.1, + ) -> None: + """ + Args: + num_features: + The input dimension of the model. + num_classes: + The output dimension of the model. + subsampling_factor: + Number of output frames is num_in_frames // subsampling_factor. + Currently, subsampling_factor MUST be 4. + d_model: + Attention dimension. + nhead: + Number of heads in multi-head attention. + Must satisfy d_model // nhead == 0. + dim_feedforward: + The output dimension of the feedforward layers in encoder/decoder. + num_encoder_layers: + Number of encoder layers. + num_decoder_layers: + Number of decoder layers. + dropout: + Dropout in encoder/decoder. + normalize_before: + If True, use pre-layer norm; False to use post-layer norm. + vgg_frontend: + True to use vgg style frontend for subsampling. + use_feat_batchnorm: + True to use batchnorm for the input layer. + Float value to scale the input layer. + False to do nothing. + """ + super().__init__() + self.use_feat_batchnorm = use_feat_batchnorm + assert isinstance(use_feat_batchnorm, (float, bool)) + if isinstance(use_feat_batchnorm, bool) and use_feat_batchnorm: + self.feat_batchnorm = nn.BatchNorm1d(num_features) + + self.num_features = num_features + self.num_classes = num_classes + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + # self.encoder_embed converts the input of shape (N, T, num_classes) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_classes -> d_model + if vgg_frontend: + self.encoder_embed = VggSubsampling(num_features, d_model) + else: + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.encoder_pos = PositionalEncoding(d_model, dropout) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + encoder_norm = nn.LayerNorm(d_model) + else: + encoder_norm = None + + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=num_encoder_layers, + norm=encoder_norm, + ) + + # TODO(fangjun): remove dropout + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) + ) + + if num_decoder_layers > 0: + self.decoder_num_class = ( + self.num_classes + ) # bpe model already has sos/eos symbol + + self.decoder_embed = nn.Embedding( + num_embeddings=self.decoder_num_class, embedding_dim=d_model + ) + self.decoder_pos = PositionalEncoding(d_model, dropout) + + decoder_layer = TransformerDecoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + decoder_norm = nn.LayerNorm(d_model) + else: + decoder_norm = None + + self.decoder = nn.TransformerDecoder( + decoder_layer=decoder_layer, + num_layers=num_decoder_layers, + norm=decoder_norm, + ) + + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) + + self.decoder_criterion = LabelSmoothingLoss() + else: + self.decoder_criterion = None + + def forward( + self, x: torch.Tensor, supervision: Optional[Supervisions] = None + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (N, T, C). + supervision: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + + Returns: + Return a tuple containing 3 tensors: + - CTC output for ctc decoding. Its shape is (N, T, C) + - Encoder output with shape (T, N, C). It can be used as key and + value for the decoder. + - Encoder output padding mask. It can be used as + memory_key_padding_mask for the decoder. Its shape is (N, T). + It is None if `supervision` is None. + """ + if ( + isinstance(self.use_feat_batchnorm, bool) + and self.use_feat_batchnorm + ): + x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) + x = self.feat_batchnorm(x) + x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) + if isinstance(self.use_feat_batchnorm, float): + x *= self.use_feat_batchnorm + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) + x = self.ctc_output(encoder_memory) + return x, encoder_memory, memory_key_padding_mask + + def run_encoder( + self, x: torch.Tensor, supervisions: Optional[Supervisions] = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Run the transformer encoder. + + Args: + x: + The model input. Its shape is (N, T, C). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute the encoder padding mask, which is used as memory key + padding mask for the decoder. + Returns: + Return a tuple with two tensors: + - The encoder output, with shape (T, N, C) + - encoder padding mask, with shape (N, T). + The mask is None if `supervisions` is None. + It is used as memory key padding mask in the decoder. + """ + x = self.encoder_embed(x) + x = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + mask = encoder_padding_mask(x.size(0), supervisions) + mask = mask.to(x.device) if mask is not None else None + x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) + + return x, mask + + def ctc_output(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + The output tensor from the transformer encoder. + Its shape is (T, N, C) + + Returns: + Return a tensor that can be used for CTC decoding. + Its shape is (N, T, C) + """ + x = self.encoder_output_layer(x) + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) + return x + + @torch.jit.export + def decoder_forward( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[List[int]], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape (T, N, C) + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs. Each sublist contains IDs for an utterance. + The IDs can be either phone IDs or word piece IDs. + sos_id: + sos token id + eos_id: + eos token id + + Returns: + A scalar, the **sum** of label smoothing loss over utterances + in the batch without any normalization. + """ + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) + + device = memory.device + ys_in_pad = ys_in_pad.to(device) + ys_out_pad = ys_out_pad.to(device) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, N, C) + pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) + + decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) + + return decoder_loss + + @torch.jit.export + def decoder_nll( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[torch.Tensor], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape (T, N, C) + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs (e.g., word piece IDs). + Each sublist represents an utterance. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + Returns: + A 2-D tensor of shape (len(token_ids), max_token_length) + representing the cross entropy loss (i.e., negative log-likelihood). + """ + # The common part between this function and decoder_forward could be + # extracted as a separate function. + if isinstance(token_ids[0], torch.Tensor): + # This branch is executed by torchscript in C++. + # See https://github.com/k2-fsa/k2/pull/870 + # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286 + token_ids = [tolist(t) for t in token_ids] + + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) + + device = memory.device + ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) + ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, B, F) + pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) + pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + pred_pad.view(-1, self.decoder_num_class), + ys_out_pad.view(-1), + ignore_index=-1, + reduction="none", + ) + + nll = nll.view(pred_pad.shape[0], -1) + + return nll + + +class TransformerEncoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerEncoderLayer. + Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + normalize_before: + whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + normalize_before: bool = True, + ) -> None: + super(TransformerEncoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self.normalize_before = normalize_before + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerEncoderLayer, self).__setstate__(state) + + def forward( + self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional) + + Shape: + src: (S, N, E). + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + residual = src + if self.normalize_before: + src = self.norm1(src) + src2 = self.self_attn( + src, + src, + src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout1(src2) + if not self.normalize_before: + src = self.norm1(src) + + residual = src + if self.normalize_before: + src = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = residual + self.dropout2(src2) + if not self.normalize_before: + src = self.norm2(src) + return src + + +class TransformerDecoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerDecoderLayer. + Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + normalize_before: bool = True, + ) -> None: + super(TransformerDecoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self.normalize_before = normalize_before + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerDecoderLayer, self).__setstate__(state) + + def forward( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: + the sequence to the decoder layer (required). + memory: + the sequence from the last layer of the encoder (required). + tgt_mask: + the mask for the tgt sequence (optional). + memory_mask: + the mask for the memory sequence (optional). + tgt_key_padding_mask: + the mask for the tgt keys per batch (optional). + memory_key_padding_mask: + the mask for the memory keys per batch (optional). + + Shape: + tgt: (T, N, E). + memory: (S, N, E). + tgt_mask: (T, T). + memory_mask: (T, S). + tgt_key_padding_mask: (N, T). + memory_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + tgt2 = self.self_attn( + tgt, + tgt, + tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask, + )[0] + tgt = residual + self.dropout1(tgt2) + if not self.normalize_before: + tgt = self.norm1(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm2(tgt) + tgt2 = self.src_attn( + tgt, + memory, + memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = residual + self.dropout2(tgt2) + if not self.normalize_before: + tgt = self.norm2(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = residual + self.dropout3(tgt2) + if not self.normalize_before: + tgt = self.norm3(tgt) + return tgt + + +def _get_activation_fn(activation: str): + if activation == "relu": + return nn.functional.relu + elif activation == "gelu": + return nn.functional.gelu + + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) + + +class PositionalEncoding(nn.Module): + """This class implements the positional encoding + proposed in the following paper: + + - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf + + PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) + PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) + + Note:: + + 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) + = exp(-1* 2i / d_model * log(100000)) + = exp(2i * -(log(10000) / d_model)) + """ + + def __init__(self, d_model: int, dropout: float = 0.1) -> None: + """ + Args: + d_model: + Embedding dimension. + dropout: + Dropout probability to be applied to the output of this module. + """ + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = nn.Dropout(p=dropout) + # not doing: self.pe = None because of errors thrown by torchscript + self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32) + + def extend_pe(self, x: torch.Tensor) -> None: + """Extend the time t in the positional encoding if required. + + The shape of `self.pe` is (1, T1, d_model). The shape of the input x + is (N, T, d_model). If T > T1, then we change the shape of self.pe + to (N, T, d_model). Otherwise, nothing is done. + + Args: + x: + It is a tensor of shape (N, T, C). + Returns: + Return None. + """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + # Now pe is of shape (1, T, d_model), where T is x.size(1) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Add positional encoding. + + Args: + x: + Its shape is (N, T, C) + + Returns: + Return a tensor of shape (N, T, C) + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1), :] + return self.dropout(x) + + +class Noam(object): + """ + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) + + +def encoder_padding_mask( + max_len: int, supervisions: Optional[Supervisions] = None +) -> Optional[torch.Tensor]: + """Make mask tensor containing indexes of padded part. + + TODO:: + This function **assumes** that the model uses + a subsampling factor of 4. We should remove that + assumption later. + + Args: + max_len: + Maximum length of input features. + CAUTION: It is the length after subsampling. + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + + Returns: + Tensor: Mask tensor of dimension (batch_size, input_length), + True denote the masked indices. + """ + if supervisions is None: + return None + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"], + supervisions["num_frames"], + ), + 1, + ).to(torch.int32) + + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] + for idx in range(supervision_segments.size(0)): + # Note: TorchScript doesn't allow to unpack tensors as tuples + sequence_idx = supervision_segments[idx, 0].item() + start_frame = supervision_segments[idx, 1].item() + num_frames = supervision_segments[idx, 2].item() + lengths[sequence_idx] = start_frame + num_frames + + lengths = [((i - 1) // 2 - 1) // 2 for i in lengths] + bs = int(len(lengths)) + seq_range = torch.arange(0, max_len, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len) + # Note: TorchScript doesn't implement Tensor.new() + seq_length_expand = torch.tensor( + lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype + ).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + return mask + + +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: + """Generate a length mask for input. + + The masked position are filled with True, + Unmasked positions are filled with False. + + Args: + ys_pad: + padded tensor of dimension (batch_size, input_length). + ignore_id: + the ignored number (the padding number) in ys_pad + + Returns: + Tensor: + a bool tensor of the same shape as the input tensor. + """ + ys_mask = ys_pad == ignore_id + return ys_mask + + +def generate_square_subsequent_mask(sz: int) -> torch.Tensor: + """Generate a square mask for the sequence. The masked positions are + filled with float('-inf'). Unmasked positions are filled with float(0.0). + The mask can be used for masked self-attention. + + For instance, if sz is 3, it returns:: + + tensor([[0., -inf, -inf], + [0., 0., -inf], + [0., 0., 0]]) + + Args: + sz: mask size + + Returns: + A square mask of dimension (sz, sz) + """ + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = ( + mask.float() + .masked_fill(mask == 0, float("-inf")) + .masked_fill(mask == 1, float(0.0)) + ) + return mask + + +def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: + """Prepend sos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + sos_id: + The ID of the SOS token. + + Return: + Return a new list-of-list, where each sublist starts + with SOS ID. + """ + return [[sos_id] + utt for utt in token_ids] + + +def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: + """Append eos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + eos_id: + The ID of the EOS token. + + Return: + Return a new list-of-list, where each sublist ends + with EOS ID. + """ + return [utt + [eos_id] for utt in token_ids] + + +def tolist(t: torch.Tensor) -> List[int]: + """Used by jit""" + return torch.jit.annotate(List[int], t.tolist()) diff --git a/egs/gigaspeech/ASR/local/__init__.py b/egs/gigaspeech/ASR/local/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/gigaspeech/ASR/local/compile_hlg.py b/egs/gigaspeech/ASR/local/compile_hlg.py new file mode 120000 index 000000000..471aa7fb4 --- /dev/null +++ b/egs/gigaspeech/ASR/local/compile_hlg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py new file mode 100755 index 000000000..9f1039893 --- /dev/null +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +# Copyright 2021 Johns Hopkins University (Piotr Żelasko) +# Copyright 2021 Xiaomi Corp. (Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + KaldifeatFbank, + KaldifeatFbankConfig, +) + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_gigaspeech_dev_test(): + in_out_dir = Path("data/fbank") + # number of workers in dataloader + num_workers = 20 + + # number of seconds in a batch + batch_duration = 600 + + subsets = ("DEV", "TEST") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + + logging.info(f"device: {device}") + + for partition in subsets: + cuts_path = in_out_dir / f"cuts_{partition}.jsonl.gz" + if cuts_path.is_file(): + logging.info(f"{cuts_path} exists - skipping") + continue + + raw_cuts_path = in_out_dir / f"cuts_{partition}_raw.jsonl.gz" + + logging.info(f"Loading {raw_cuts_path}") + cut_set = CutSet.from_file(raw_cuts_path) + + logging.info("Computing features") + + cut_set = cut_set.compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{in_out_dir}/feats_{partition}", + num_workers=num_workers, + batch_duration=batch_duration, + ) + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, min_duration=None + ) + + logging.info(f"Saving to {cuts_path}") + cut_set.to_file(cuts_path) + logging.info(f"Saved to {cuts_path}") + + +def main(): + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + logging.basicConfig(format=formatter, level=logging.INFO) + + compute_fbank_gigaspeech_dev_test() + + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py new file mode 100755 index 000000000..9dd3c046d --- /dev/null +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +# Copyright 2021 Johns Hopkins University (Piotr Żelasko) +# Copyright 2021 Xiaomi Corp. (Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from datetime import datetime +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + KaldifeatFbank, + KaldifeatFbankConfig, +) + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--num-workers", + type=int, + default=20, + help="Number of dataloading workers used for reading the audio.", + ) + parser.add_argument( + "--batch-duration", + type=float, + default=600.0, + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", + ) + + parser.add_argument( + "--num-splits", + type=int, + required=True, + help="The number of splits of the XL subset", + ) + + parser.add_argument( + "--start", + type=int, + default=0, + help="Process pieces starting from this number (inclusive).", + ) + + parser.add_argument( + "--stop", + type=int, + default=-1, + help="Stop processing pieces until this number (exclusive).", + ) + return parser + + +def compute_fbank_gigaspeech_splits(args): + num_splits = args.num_splits + output_dir = "data/fbank/XL_split" + output_dir = Path(output_dir) + assert output_dir.exists(), f"{output_dir} does not exist!" + + num_digits = 8 # num_digits is fixed by lhotse split-lazy + + start = args.start + stop = args.stop + if stop < start: + stop = num_splits + + stop = min(stop, num_splits) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + logging.info(f"device: {device}") + + for i in range(start, stop): + idx = f"{i + 1}".zfill(num_digits) + logging.info(f"Processing {idx}/{num_splits}") + + cuts_path = output_dir / f"cuts_XL.{idx}.jsonl.gz" + if cuts_path.is_file(): + logging.info(f"{cuts_path} exists - skipping") + continue + + raw_cuts_path = output_dir / f"cuts_XL_raw.{idx}.jsonl.gz" + + logging.info(f"Loading {raw_cuts_path}") + cut_set = CutSet.from_file(raw_cuts_path) + + logging.info("Computing features") + + cut_set = cut_set.compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{output_dir}/feats_XL_{idx}", + num_workers=args.num_workers, + batch_duration=args.batch_duration, + ) + + logging.info("About to split cuts into smaller chunks.") + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, min_duration=None + ) + + logging.info(f"Saving to {cuts_path}") + cut_set.to_file(cuts_path) + logging.info(f"Saved to {cuts_path}") + + +def main(): + now = datetime.now() + date_time = now.strftime("%Y-%m-%d-%H-%M-%S") + + log_filename = "log-compute_fbank_gigaspeech_splits" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + log_filename = f"{log_filename}-{date_time}" + + logging.basicConfig( + filename=log_filename, + format=formatter, + level=logging.INFO, + filemode="w", + ) + + console = logging.StreamHandler() + console.setLevel(logging.INFO) + console.setFormatter(logging.Formatter(formatter)) + logging.getLogger("").addHandler(console) + + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + compute_fbank_gigaspeech_splits(args) + + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/ASR/local/compute_fbank_musan.py b/egs/gigaspeech/ASR/local/compute_fbank_musan.py new file mode 100755 index 000000000..219f4bdca --- /dev/null +++ b/egs/gigaspeech/ASR/local/compute_fbank_musan.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +# Copyright 2021 Johns Hopkins University (Piotr Żelasko) +# Copyright 2021 Xiaomi Corp. (Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + KaldifeatFbank, + KaldifeatFbankConfig, + combine, +) +from lhotse.recipes.utils import read_manifests_if_cached + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_musan(): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + + # number of workers in dataloader + num_workers = 10 + + # number of seconds in a batch + batch_duration = 600 + + dataset_parts = ( + "music", + "speech", + "noise", + ) + + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, output_dir=src_dir + ) + assert manifests is not None + + musan_cuts_path = output_dir / "cuts_musan.json.gz" + + if musan_cuts_path.is_file(): + logging.info(f"{musan_cuts_path} already exists - skipping") + return + + logging.info("Extracting features for Musan") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + + logging.info(f"device: {device}") + + musan_cuts = ( + CutSet.from_manifests( + recordings=combine( + part["recordings"] for part in manifests.values() + ) + ) + .cut_into_windows(10.0) + .filter(lambda c: c.duration > 5) + .compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{output_dir}/feats_musan", + num_workers=num_workers, + batch_duration=batch_duration, + ) + ) + musan_cuts.to_json(musan_cuts_path) + + +def main(): + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + logging.basicConfig(format=formatter, level=logging.INFO) + + compute_fbank_musan() + + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/ASR/local/convert_transcript_words_to_tokens.py b/egs/gigaspeech/ASR/local/convert_transcript_words_to_tokens.py new file mode 120000 index 000000000..2ce13fd69 --- /dev/null +++ b/egs/gigaspeech/ASR/local/convert_transcript_words_to_tokens.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/local/prepare_lang.py b/egs/gigaspeech/ASR/local/prepare_lang.py new file mode 120000 index 000000000..747f2ab39 --- /dev/null +++ b/egs/gigaspeech/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/local/prepare_lang_bpe.py b/egs/gigaspeech/ASR/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/gigaspeech/ASR/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py new file mode 100755 index 000000000..0cec82ad5 --- /dev/null +++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +# Copyright 2021 Johns Hopkins University (Piotr Żelasko) +# Copyright 2021 Xiaomi Corp. (Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re +from pathlib import Path + +from lhotse import CutSet, SupervisionSegment +from lhotse.recipes.utils import read_manifests_if_cached + +# Similar text filtering and normalization procedure as in: +# https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh + + +def normalize_text( + utt: str, + punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"), + whitespace_pattern=re.compile(r"\s\s+"), +) -> str: + return whitespace_pattern.sub(" ", punct_pattern.sub("", utt)) + + +def has_no_oov( + sup: SupervisionSegment, + oov_pattern=re.compile(r"<(SIL|MUSIC|NOISE|OTHER)>"), +) -> bool: + return oov_pattern.search(sup.text) is None + + +def preprocess_giga_speech(): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + output_dir.mkdir(exist_ok=True) + + dataset_parts = ( + "DEV", + "TEST", + "XL", + ) + + logging.info("Loading manifest (may take 4 minutes)") + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix="gigaspeech", + suffix="jsonl.gz", + ) + assert manifests is not None + + for partition, m in manifests.items(): + logging.info(f"Processing {partition}") + raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz" + if raw_cuts_path.is_file(): + logging.info(f"{partition} already exists - skipping") + continue + + # Note this step makes the recipe different than LibriSpeech: + # We must filter out some utterances and remove punctuation + # to be consistent with Kaldi. + logging.info("Filtering OOV utterances from supervisions") + m["supervisions"] = m["supervisions"].filter(has_no_oov) + logging.info(f"Normalizing text in {partition}") + for sup in m["supervisions"]: + sup.text = normalize_text(sup.text) + + # Create long-recording cut manifests. + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + # Run data augmentation that needs to be done in the + # time domain. + if partition not in ["DEV", "TEST"]: + logging.info( + f"Speed perturb for {partition} with factors 0.9 and 1.1 " + "(Perturbing may take 8 minutes and saving may take 20 minutes)" + ) + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) + logging.info(f"Saving to {raw_cuts_path}") + cut_set.to_file(raw_cuts_path) + + +def main(): + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + logging.basicConfig(format=formatter, level=logging.INFO) + + preprocess_giga_speech() + + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/ASR/local/train_bpe_model.py b/egs/gigaspeech/ASR/local/train_bpe_model.py new file mode 120000 index 000000000..6fad36421 --- /dev/null +++ b/egs/gigaspeech/ASR/local/train_bpe_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/prepare.sh b/egs/gigaspeech/ASR/prepare.sh new file mode 100755 index 000000000..fd2532741 --- /dev/null +++ b/egs/gigaspeech/ASR/prepare.sh @@ -0,0 +1,325 @@ +#!/usr/bin/env bash + +set -eou pipefail + +nj=15 +stage=0 +stop_stage=100 + +# Split XL subset to a number of pieces (about 2000) +# This is to avoid OOM during feature extraction. +num_per_split=50 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/GigaSpeech +# You can find audio, dict, GigaSpeech.json inside it. +# You can apply for the download credentials by following +# https://github.com/SpeechColab/GigaSpeech#download +# +# - $dl_dir/lm +# This directory contains the language model downloaded from +# https://huggingface.co/wgb14/gigaspeech_lm +# +# - 3gram_pruned_1e7.arpa.gz +# - 4gram.arpa.gz +# - lexicon.txt +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/lang_bpe_xxx, +# data/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + 500 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "stage -1: Download LM" + # We assume that you have installed the git-lfs, if not, you could install it + # using: `sudo apt-get install git-lfs && git-lfs install` + [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm + git clone https://huggingface.co/wgb14/gigaspeech_lm $dl_dir/lm + gunzip -c $dl_dir/lm/3gram_pruned_1e7.arpa.gz > $dl_dir/lm/3gram_pruned_1e7.arpa + gunzip -c $dl_dir/lm/4gram.arpa.gz > $dl_dir/lm/4gram.arpa +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + [ ! -e $dl_dir/GigaSpeech ] && mkdir -p $dl_dir/GigaSpeech + + # If you have pre-downloaded it to /path/to/GigaSpeech, + # you can create a symlink + # + # ln -sfv /path/to/GigaSpeech $dl_dir/GigaSpeech + # + if [ ! -d $dl_dir/GigaSpeech/audio ] && [ ! -f $dl_dir/GigaSpeech.json ]; then + # Check credentials. + if [ ! -f $dl_dir/password ]; then + echo -n "$0: Please apply for the download credentials by following" + echo -n "https://github.com/SpeechColab/GigaSpeech#download" + echo " and save it to $dl_dir/password." + exit 1; + fi + PASSWORD=`cat $dl_dir/password 2>/dev/null` + if [ -z "$PASSWORD" ]; then + echo "$0: Error, $dl_dir/password is empty." + exit 1; + fi + PASSWORD_MD5=`echo $PASSWORD | md5sum | cut -d ' ' -f 1` + if [[ $PASSWORD_MD5 != "dfbf0cde1a3ce23749d8d81e492741b8" ]]; then + echo "$0: Error, invalid $dl_dir/password." + exit 1; + fi + # Download XL, DEV and TEST sets by default. + lhotse download gigaspeech --subset auto --host tsinghua \ + $dl_dir/password $dl_dir/GigaSpeech + fi + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/ + # + if [ ! -d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare GigaSpeech manifest (may take 15 minutes)" + # We assume that you have downloaded the GigaSpeech corpus + # to $dl_dir/GigaSpeech + mkdir -p data/manifests + lhotse prepare gigaspeech --subset auto -j $nj \ + $dl_dir/GigaSpeech data/manifests +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to $dl_dir/musan + mkdir -p data/manifests + lhotse prepare musan $dl_dir/musan data/manifests +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "State 3: Preprocess GigaSpeech manifest" + if [ ! -f data/fbank/.preprocess_complete ]; then + python3 ./local/preprocess_gigaspeech.py + touch data/fbank/.preprocess_complete + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute features for DEV and TEST subsets of GigaSpeech (may take 2 minutes)" + python3 ./local/compute_fbank_gigaspeech_dev_test.py +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Split XL subset into pieces (may take 30 minutes)" + split_dir=data/fbank/XL_split + if [ ! -f $split_dir/.split_completed ]; then + lhotse split-lazy ./data/fbank/cuts_XL_raw.jsonl.gz $split_dir $num_per_split + touch $split_dir/.split_completed + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Compute features for XL" + num_splits=$(find data/fbank/XL_split -name "cuts_XL_raw.*.jsonl.gz" | wc -l) + python3 ./local/compute_fbank_gigaspeech_splits.py \ + --num-workers 20 \ + --batch-duration 600 \ + --num-splits $num_splits +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Combine features for XL (may take 3 hours)" + if [ ! -f data/fbank/cuts_XL.jsonl.gz ]; then + pieces=$(find data/fbank/XL_split -name "cuts_XL.*.jsonl.gz") + lhotse combine $pieces data/fbank/cuts_XL.jsonl.gz + fi +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Compute fbank for musan" + mkdir -p data/fbank + ./local/compute_fbank_musan.py +fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Prepare phone based lang" + lang_dir=data/lang_phone + mkdir -p $lang_dir + + (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | + cat - $dl_dir/lm/lexicon.txt | + sort | uniq > $lang_dir/lexicon.txt + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang.py --lang-dir $lang_dir + fi + + if [ ! -f $lang_dir/transcript_words.txt ]; then + gunzip -c "data/manifests/gigaspeech_supervisions_XL.jsonl.gz" \ + | jq '.text' \ + | sed 's/"//g' \ + > $lang_dir/transcript_words.txt + + # Delete utterances with garbage meta tags + garbage_utterance_tags=" " + for tag in $garbage_utterance_tags; do + sed -i "/${tag}/d" $lang_dir/transcript_words.txt + done + + # Delete punctuations in utterances + punctuation_tags=" " + for tag in $punctuation_tags; do + sed -i "s/${tag}//g" $lang_dir/transcript_words.txt + done + + # Ensure space only appears once + sed -i 's/\t/ /g' $lang_dir/transcript_words.txt + sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt + fi + + cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \ + | sort -u | sed '/^$/d' > $lang_dir/words.txt + (echo '!SIL'; echo ''; echo ''; ) | + cat - $lang_dir/words.txt | sort | uniq | awk ' + BEGIN { + print " 0"; + } + { + if ($1 == "") { + print " is in the vocabulary!" | "cat 1>&2" + exit 1; + } + if ($1 == "") { + print " is in the vocabulary!" | "cat 1>&2" + exit 1; + } + printf("%s %d\n", $1, NR); + } + END { + printf("#0 %d\n", NR+1); + printf(" %d\n", NR+2); + printf(" %d\n", NR+3); + }' > $lang_dir/words || exit 1; + mv $lang_dir/words $lang_dir/words.txt +fi + +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Prepare BPE based lang" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + # We reuse words.txt from phone based lexicon + # so that the two can share G.pt later. + cp data/lang_phone/{words.txt,transcript_words.txt} $lang_dir + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + fi + done +fi + +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Prepare bigram P" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! -f $lang_dir/transcript_tokens.txt ]; then + ./local/convert_transcript_words_to_tokens.py \ + --lexicon $lang_dir/lexicon.txt \ + --transcript $lang_dir/transcript_words.txt \ + --oov "" \ + > $lang_dir/transcript_tokens.txt + fi + + if [ ! -f $lang_dir/P.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order 2 \ + -text $lang_dir/transcript_tokens.txt \ + -lm $lang_dir/P.arpa + fi + + if [ ! -f $lang_dir/P.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$lang_dir/tokens.txt" \ + --disambig-symbol='#0' \ + --max-order=2 \ + $lang_dir/P.arpa > $lang_dir/P.fst.txt + fi + done +fi + +if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then + log "Stage 12: Prepare G" + # We assume you have install kaldilm, if not, please install + # it using: pip install kaldilm + + mkdir -p data/lm + + if [ ! -f data/lm/G_3_gram.fst.txt ]; then + # It is used in building HLG + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + $dl_dir/lm/3gram_pruned_1e7.arpa > data/lm/G_3_gram.fst.txt + fi + + if [ ! -f data/lm/G_4_gram.fst.txt ]; then + # It is used for LM rescoring + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=4 \ + $dl_dir/lm/4gram.arpa > data/lm/G_4_gram.fst.txt + fi +fi + +if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then + log "Stage 13: Compile HLG" + ./local/compile_hlg.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir + done +fi diff --git a/egs/gigaspeech/ASR/shared b/egs/gigaspeech/ASR/shared new file mode 120000 index 000000000..4cbd91a7e --- /dev/null +++ b/egs/gigaspeech/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/icefall/decode.py b/icefall/decode.py index d3e420eec..94f3e88ba 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -630,15 +630,37 @@ def rescore_with_n_best_list( assert G.device == device assert hasattr(G, "aux_labels") is False - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - # nbest.fsa.scores are all 0s at this point + max_loop_count = 10 + loop_count = 0 + while loop_count <= max_loop_count: + try: + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # nbest.fsa.scores are all 0s at this point + nbest = nbest.intersect(lattice) + break + except RuntimeError as e: + logging.info(f"Caught exception:\n{e}\n") + logging.info(f"num_paths before decreasing: {num_paths}") + num_paths = int(num_paths / 2) + if loop_count >= max_loop_count or num_paths <= 0: + logging.info( + "Return None as the resulting lattice is too large." + ) + return None + logging.info( + "This OOM is not an error. You can ignore it. " + "If your model does not converge well, or --max-duration " + "is too large, or the input sound file is difficult to " + "decode, you will meet this exception." + ) + logging.info(f"num_paths after decreasing: {num_paths}") + loop_count += 1 - nbest = nbest.intersect(lattice) # Now nbest.fsa has its scores set assert hasattr(nbest.fsa, "lm_scores") @@ -824,15 +846,37 @@ def rescore_with_attention_decoder( ngram_lm_scale_attention_scale and the value is the best decoding path for each utterance in the lattice. """ - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - # nbest.fsa.scores are all 0s at this point + max_loop_count = 10 + loop_count = 0 + while loop_count <= max_loop_count: + try: + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # nbest.fsa.scores are all 0s at this point + nbest = nbest.intersect(lattice) + break + except RuntimeError as e: + logging.info(f"Caught exception:\n{e}\n") + logging.info(f"num_paths before decreasing: {num_paths}") + num_paths = int(num_paths / 2) + if loop_count >= max_loop_count or num_paths <= 0: + logging.info( + "Return None as the resulting lattice is too large." + ) + return None + logging.info( + "This OOM is not an error. You can ignore it. " + "If your model does not converge well, or --max-duration " + "is too large, or the input sound file is difficult to " + "decode, you will meet this exception." + ) + logging.info(f"num_paths after decreasing: {num_paths}") + loop_count += 1 - nbest = nbest.intersect(lattice) # Now nbest.fsa has its scores set. # Also, nbest.fsa inherits the attributes from `lattice`. assert hasattr(nbest.fsa, "lm_scores") From 524f3aa0152b7f3c69089943b4e68db2eb0abcb7 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Thu, 14 Apr 2022 16:41:52 +0800 Subject: [PATCH 221/234] update test functions for emformer. --- .../emformer.py | 9 +- .../test_emformer.py | 184 +++++++++++++++--- 2 files changed, 162 insertions(+), 31 deletions(-) diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py index 67e9f5891..9eb5b966f 100644 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py @@ -85,8 +85,6 @@ class EmformerAttention(nn.Module): Embedding dimension. nhead (int): Number of attention heads in each Emformer layer. - dropout (float, optional): - Dropout probability. (Default: 0.0) weight_init_gain (float or None, optional): Scale factor to apply when initializing attention module parameters. (Default: ``None``) @@ -100,7 +98,6 @@ class EmformerAttention(nn.Module): self, embed_dim: int, nhead: int, - dropout: float = 0.0, weight_init_gain: Optional[float] = None, tanh_on_mem: bool = False, negative_inf: float = -1e8, @@ -115,7 +112,6 @@ class EmformerAttention(nn.Module): self.embed_dim = embed_dim self.nhead = nhead - self.dropout = dropout self.tanh_on_mem = tanh_on_mem self.negative_inf = negative_inf @@ -183,9 +179,7 @@ class EmformerAttention(nn.Module): attention_probs = nn.functional.softmax( attention_weights_float, dim=-1 ).type_as(attention_weights) - # attention_probs = nn.functional.dropout( - # attention_probs, p=float(self.dropout), training=self.training - # ) + return attention_probs def _forward_impl( @@ -512,7 +506,6 @@ class EmformerLayer(nn.Module): self.attention = EmformerAttention( embed_dim=d_model, nhead=nhead, - dropout=dropout, weight_init_gain=weight_init_gain, tanh_on_mem=tanh_on_mem, negative_inf=negative_inf, diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py index 5e08640d3..abc023bb7 100644 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py @@ -362,8 +362,9 @@ def test_emformer_attention_forward_infer_consistency(): left_context_length=L, right_context_length=R, max_memory_size=M, - dropout=0.0, + dropout=0.1, ) + encoder.eval() encoder_layer = encoder.emformer_layers[0] x = torch.randn(U + R, 1, D) @@ -415,12 +416,15 @@ def test_emformer_attention_forward_infer_consistency(): chunk_memory, state, ) - infer_output_utterance = infer_output_right_context_utterance[ + infer_output_chunk = infer_output_right_context_utterance[ chunk_right_context.size(0) : # noqa ] - print( - infer_output_utterance - - forward_output_utterance[start_idx:end_idx] + forward_output_chunk = forward_output_utterance[start_idx:end_idx] + assert torch.allclose( + infer_output_chunk, + forward_output_chunk, + atol=1e-6, + rtol=0.0, ) @@ -444,8 +448,9 @@ def test_emformer_layer_forward_infer_consistency(): left_context_length=L, right_context_length=R, max_memory_size=M, - dropout=0.0, + dropout=0.1, ) + encoder.eval() encoder_layer = encoder.emformer_layers[0] x = torch.randn(U + R, 1, D) @@ -485,7 +490,7 @@ def test_emformer_layer_forward_infer_consistency(): else torch.empty(0).to(dtype=x.dtype, device=x.device) ) ( - infer_output_utterance, + infer_output_chunk, infer_right_context, infer_output_memory, state, @@ -496,9 +501,12 @@ def test_emformer_layer_forward_infer_consistency(): chunk_memory, state, ) - print( - infer_output_utterance - - forward_output_utterance[start_idx:end_idx] + forward_output_chunk = forward_output_utterance[start_idx:end_idx] + assert torch.allclose( + infer_output_chunk, + forward_output_chunk, + atol=1e-5, + rtol=0.0, ) @@ -522,8 +530,9 @@ def test_emformer_encoder_forward_infer_consistency(): left_context_length=L, right_context_length=R, max_memory_size=M, - dropout=0.0, + dropout=0.1, ) + encoder.eval() x = torch.randn(U + R, 1, D) lengths = torch.tensor([U + R]) @@ -537,23 +546,152 @@ def test_emformer_encoder_forward_infer_consistency(): chunk = x[start_idx : end_idx + R] # noqa chunk_right_context = x[end_idx : end_idx + R] # noqa chunk_length = torch.tensor([chunk_length]) - infer_output, infer_output_lengths, states = encoder.infer( + infer_output_chunk, infer_output_lengths, states = encoder.infer( chunk, chunk_length, states, ) - print(infer_output - forward_output[start_idx:end_idx]) + forward_output_chunk = forward_output[start_idx:end_idx] + assert torch.allclose( + infer_output_chunk, + forward_output_chunk, + atol=1e-5, + rtol=0.0, + ) + + +def test_emformer_infer_batch_single_consistency(): + """Test consistency of cached states and output logits between single + utterance inference and batch inference.""" + from emformer import Emformer + + num_features = 80 + output_dim = 1000 + chunk_length = 8 + num_chunks = 3 + U = num_chunks * chunk_length + L, R = 128, 4 + B, D = 2, 256 + num_encoder_layers = 2 + for use_memory in [True, False]: + if use_memory: + M = 3 + else: + M = 0 + model = Emformer( + num_features=num_features, + output_dim=output_dim, + chunk_length=chunk_length, + subsampling_factor=4, + d_model=D, + num_encoder_layers=num_encoder_layers, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + vgg_frontend=False, + ) + model.eval() + + def save_states(states): + saved_states = [] + for layer_idx in range(len(states)): + layer_state = [] + layer_state.append(states[layer_idx][0].clone()) # memory + layer_state.append( + states[layer_idx][1].clone() + ) # left_context_key + layer_state.append( + states[layer_idx][2].clone() + ) # left_context_val + layer_state.append(states[layer_idx][3].clone()) # past_length + saved_states.append(layer_state) + return saved_states + + def assert_states_equal(saved_states, states, sample_idx): + for layer_idx in range(len(saved_states)): + # assert eqaul memory + assert torch.allclose( + states[layer_idx][0], + saved_states[layer_idx][0][ + :, sample_idx : sample_idx + 1 # noqa + ], + atol=1e-5, + rtol=0.0, + ) + # assert equal left_context_key + assert torch.allclose( + states[layer_idx][1], + saved_states[layer_idx][1][ + :, sample_idx : sample_idx + 1 # noqa + ], + atol=1e-5, + rtol=0.0, + ) + # assert equal left_context_val + assert torch.allclose( + states[layer_idx][2], + saved_states[layer_idx][2][ + :, sample_idx : sample_idx + 1 # noqa + ], + atol=1e-5, + rtol=0.0, + ) + # assert eqaul past_length + assert torch.equal( + states[layer_idx][3], + saved_states[layer_idx][3][ + :, sample_idx : sample_idx + 1 # noqa + ], + ) + + x = torch.randn(B, U + R + 3, num_features) + batch_logits = [] + batch_states = [] + states = None + for chunk_idx in range(num_chunks): + start_idx = chunk_idx * chunk_length + end_idx = start_idx + chunk_length + chunk = x[:, start_idx : end_idx + R + 3] # noqa + lengths = torch.tensor([chunk_length + R + 3]).expand(B) + logits, output_lengths, states = model.infer(chunk, lengths, states) + batch_logits.append(logits) + batch_states.append(save_states(states)) + batch_logits = torch.cat(batch_logits, dim=1) + + single_logits = [] + for sample_idx in range(B): + sample = x[sample_idx : sample_idx + 1] # noqa + chunk_logits = [] + states = None + for chunk_idx in range(num_chunks): + start_idx = chunk_idx * chunk_length + end_idx = start_idx + chunk_length + chunk = sample[:, start_idx : end_idx + R + 3] # noqa + lengths = torch.tensor([chunk_length + R + 3]) + logits, output_lengths, states = model.infer( + chunk, lengths, states + ) + chunk_logits.append(logits) + + assert_states_equal(batch_states[chunk_idx], states, sample_idx) + + chunk_logits = torch.cat(chunk_logits, dim=1) + single_logits.append(chunk_logits) + single_logits = torch.cat(single_logits, dim=0) + + assert torch.allclose(batch_logits, single_logits, atol=1e-5, rtol=0.0) if __name__ == "__main__": - # test_emformer_attention_forward() - # test_emformer_attention_infer() - # test_emformer_layer_forward() - # test_emformer_layer_infer() - # test_emformer_encoder_forward() - # test_emformer_encoder_infer() - # test_emformer_forward() - # test_emformer_infer() - # test_emformer_attention_forward_infer_consistency() - # test_emformer_layer_forward_infer_consistency() + test_emformer_attention_forward() + test_emformer_attention_infer() + test_emformer_layer_forward() + test_emformer_layer_infer() + test_emformer_encoder_forward() + test_emformer_encoder_infer() + test_emformer_forward() + test_emformer_infer() + test_emformer_attention_forward_infer_consistency() + test_emformer_layer_forward_infer_consistency() test_emformer_encoder_forward_infer_consistency() + test_emformer_infer_batch_single_consistency() From 32420cc3e4c287ab4811ee498810bb4153015d48 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Thu, 14 Apr 2022 17:07:47 +0800 Subject: [PATCH 222/234] Add test functions for torchaudio emformer codes. --- .../ASR/transducer_emformer/test_emformer.py | 127 ++++++++++++++++++ 1 file changed, 127 insertions(+) diff --git a/egs/librispeech/ASR/transducer_emformer/test_emformer.py b/egs/librispeech/ASR/transducer_emformer/test_emformer.py index d8c7b37e2..d711df957 100755 --- a/egs/librispeech/ASR/transducer_emformer/test_emformer.py +++ b/egs/librispeech/ASR/transducer_emformer/test_emformer.py @@ -65,8 +65,135 @@ def test_emformer(): print(f"Number of encoder parameters: {num_param}") +def test_emformer_infer_batch_single_consistency(): + """Test consistency of cached states and output logits between single + utterance inference and batch inference.""" + from emformer import Emformer + + num_features = 80 + output_dim = 1000 + chunk_length = 8 + num_chunks = 3 + U = num_chunks * chunk_length + L, R = 128, 4 + B, D = 2, 256 + num_encoder_layers = 4 + for use_memory in [True, False]: + if use_memory: + M = 3 + else: + M = 0 + model = Emformer( + num_features=num_features, + output_dim=output_dim, + segment_length=chunk_length, + subsampling_factor=4, + d_model=D, + nhead=4, + dim_feedforward=1024, + num_encoder_layers=num_encoder_layers, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + vgg_frontend=False, + ) + model.eval() + + def save_states(states): + saved_states = [] + for layer_idx in range(len(states)): + layer_state = [] + layer_state.append(states[layer_idx][0].clone()) # memory + layer_state.append( + states[layer_idx][1].clone() + ) # left_context_key + layer_state.append( + states[layer_idx][2].clone() + ) # left_context_val + layer_state.append(states[layer_idx][3].clone()) # past_length + saved_states.append(layer_state) + return saved_states + + def assert_states_equal(saved_states, states, sample_idx): + for layer_idx in range(len(saved_states)): + # assert eqaul memory + assert torch.allclose( + states[layer_idx][0], + saved_states[layer_idx][0][ + :, sample_idx : sample_idx + 1 # noqa + ], + atol=1e-5, + rtol=0.0, + ) + # assert equal left_context_key + assert torch.allclose( + states[layer_idx][1], + saved_states[layer_idx][1][ + :, sample_idx : sample_idx + 1 # noqa + ], + atol=1e-5, + rtol=0.0, + ) + # assert equal left_context_val + assert torch.allclose( + states[layer_idx][2], + saved_states[layer_idx][2][ + :, sample_idx : sample_idx + 1 # noqa + ], + atol=1e-5, + rtol=0.0, + ) + # assert eqaul past_length + assert torch.equal( + states[layer_idx][3], + saved_states[layer_idx][3][ + :, sample_idx : sample_idx + 1 # noqa + ], + ) + + x = torch.randn(B, U + R + 3, num_features) + batch_logits = [] + batch_states = [] + states = None + for chunk_idx in range(num_chunks): + start_idx = chunk_idx * chunk_length + end_idx = start_idx + chunk_length + chunk = x[:, start_idx : end_idx + R + 3] # noqa + lengths = torch.tensor([chunk_length + R + 3]).expand(B) + logits, output_lengths, states = model.streaming_forward( + chunk, lengths, states + ) + batch_logits.append(logits) + batch_states.append(save_states(states)) + batch_logits = torch.cat(batch_logits, dim=1) + + single_logits = [] + for sample_idx in range(B): + sample = x[sample_idx : sample_idx + 1] # noqa + chunk_logits = [] + states = None + for chunk_idx in range(num_chunks): + start_idx = chunk_idx * chunk_length + end_idx = start_idx + chunk_length + chunk = sample[:, start_idx : end_idx + R + 3] # noqa + lengths = torch.tensor([chunk_length + R + 3]) + logits, output_lengths, states = model.streaming_forward( + chunk, lengths, states + ) + chunk_logits.append(logits) + + assert_states_equal(batch_states[chunk_idx], states, sample_idx) + + chunk_logits = torch.cat(chunk_logits, dim=1) + single_logits.append(chunk_logits) + single_logits = torch.cat(single_logits, dim=0) + + assert torch.allclose(batch_logits, single_logits, atol=1e-5, rtol=0.0) + + def main(): test_emformer() + test_emformer_infer_batch_single_consistency() if __name__ == "__main__": From df7919f4bf0f3c84977363551a65a61b093ca2dd Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Thu, 14 Apr 2022 19:16:30 +0800 Subject: [PATCH 223/234] update test functions for conv_emformer_transducer/emformer.py --- .../ASR/conv_emformer_transducer/emformer.py | 12 +- .../conv_emformer_transducer/test_emformer.py | 213 ++++++++++++++++++ 2 files changed, 215 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py index e9ce56aa7..14e106460 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer/emformer.py @@ -14,8 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# It is modified based on -# https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py. +# It is modified based on https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py. # noqa import math import warnings @@ -56,8 +55,6 @@ class EmformerAttention(nn.Module): Embedding dimension. nhead (int): Number of attention heads in each Emformer layer. - dropout (float, optional): - Dropout probability. (Default: 0.0) tanh_on_mem (bool, optional): If ``True``, applies tanh to memory elements. (Default: ``False``) negative_inf (float, optional): @@ -68,7 +65,6 @@ class EmformerAttention(nn.Module): self, embed_dim: int, nhead: int, - dropout: float = 0.0, tanh_on_mem: bool = False, negative_inf: float = -1e8, ): @@ -82,7 +78,6 @@ class EmformerAttention(nn.Module): self.embed_dim = embed_dim self.nhead = nhead - self.dropout = dropout self.tanh_on_mem = tanh_on_mem self.negative_inf = negative_inf @@ -154,9 +149,7 @@ class EmformerAttention(nn.Module): attention_probs = nn.functional.softmax( attention_weights_float, dim=-1 ).type_as(attention_weights) - attention_probs = nn.functional.dropout( - attention_probs, p=float(self.dropout), training=self.training - ) + return attention_probs def _forward_impl( @@ -481,7 +474,6 @@ class EmformerLayer(nn.Module): self.attention = EmformerAttention( embed_dim=d_model, nhead=nhead, - dropout=0.0, tanh_on_mem=tanh_on_mem, negative_inf=negative_inf, ) diff --git a/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py b/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py index 41e911e17..971abca97 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer/test_emformer.py @@ -366,6 +366,216 @@ def test_emformer_infer(): assert conv_cache.shape == (B, D, K - 1) +def test_emformer_encoder_layer_forward_infer_consistency(): + from emformer import EmformerEncoder + + chunk_length = 4 + num_chunks = 3 + U = chunk_length * num_chunks + L, R = 1, 2 + D = 256 + num_encoder_layers = 1 + memory_sizes = [0, 3] + K = 3 + + for M in memory_sizes: + encoder = EmformerEncoder( + chunk_length=chunk_length, + d_model=D, + dim_feedforward=1024, + num_encoder_layers=num_encoder_layers, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + dropout=0.1, + cnn_module_kernel=K, + causal=True, + ) + encoder.eval() + encoder_layer = encoder.emformer_layers[0] + + x = torch.randn(U + R, 1, D) + lengths = torch.tensor([U]) + right_context = encoder._gen_right_context(x) + utterance = x[: x.size(0) - R] + attention_mask = encoder._gen_attention_mask(utterance) + memory = ( + encoder.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1 + ] + if encoder.use_memory + else torch.empty(0).to(dtype=x.dtype, device=x.device) + ) + ( + forward_output_utterance, + forward_output_right_context, + forward_output_memory, + ) = encoder_layer( + utterance, + lengths, + right_context, + memory, + attention_mask, + ) + + state = None + conv_cache = None + for chunk_idx in range(num_chunks): + start_idx = chunk_idx * chunk_length + end_idx = start_idx + chunk_length + chunk = x[start_idx:end_idx] + chunk_right_context = x[end_idx : end_idx + R] # noqa + chunk_length = torch.tensor([chunk_length]) + chunk_memory = ( + encoder.init_memory_op(chunk.permute(1, 2, 0)).permute(2, 0, 1) + if encoder.use_memory + else torch.empty(0).to(dtype=x.dtype, device=x.device) + ) + ( + infer_output_chunk, + infer_right_context, + infer_output_memory, + state, + conv_cache, + ) = encoder_layer.infer( + chunk, + chunk_length, + chunk_right_context, + chunk_memory, + state, + conv_cache, + ) + forward_output_chunk = forward_output_utterance[start_idx:end_idx] + assert torch.allclose( + infer_output_chunk, + forward_output_chunk, + atol=1e-5, + rtol=0.0, + ) + + +def test_emformer_encoder_forward_infer_consistency(): + from emformer import EmformerEncoder + + chunk_length = 4 + num_chunks = 3 + U = chunk_length * num_chunks + L, R = 1, 2 + D = 256 + num_encoder_layers = 3 + K = 3 + memory_sizes = [0, 3] + + for M in memory_sizes: + encoder = EmformerEncoder( + chunk_length=chunk_length, + d_model=D, + dim_feedforward=1024, + num_encoder_layers=num_encoder_layers, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + dropout=0.1, + cnn_module_kernel=K, + causal=True, + ) + encoder.eval() + + x = torch.randn(U + R, 1, D) + lengths = torch.tensor([U + R]) + + forward_output, forward_output_lengths = encoder(x, lengths) + + states = None + conv_caches = None + for chunk_idx in range(num_chunks): + start_idx = chunk_idx * chunk_length + end_idx = start_idx + chunk_length + chunk = x[start_idx : end_idx + R] # noqa + chunk_right_context = x[end_idx : end_idx + R] # noqa + chunk_length = torch.tensor([chunk_length]) + ( + infer_output_chunk, + infer_output_lengths, + states, + conv_caches, + ) = encoder.infer( + chunk, + chunk_length, + states, + conv_caches, + ) + forward_output_chunk = forward_output[start_idx:end_idx] + assert torch.allclose( + infer_output_chunk, + forward_output_chunk, + atol=1e-5, + rtol=0.0, + ) + + +def test_emformer_forward_infer_consistency(): + from emformer import Emformer + + num_features = 80 + output_dim = 1000 + chunk_length = 8 + num_chunks = 3 + U = chunk_length * num_chunks + L, R = 128, 4 + D = 256 + num_encoder_layers = 2 + K = 3 + memory_sizes = [0, 3] + + for M in memory_sizes: + model = Emformer( + num_features=num_features, + output_dim=output_dim, + chunk_length=chunk_length, + subsampling_factor=4, + d_model=D, + num_encoder_layers=num_encoder_layers, + cnn_module_kernel=K, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + dropout=0.1, + vgg_frontend=False, + causal=True, + ) + model.eval() + + x = torch.randn(1, U + R + 3, num_features) + x_lens = torch.tensor([x.size(1)]) + + # forward mode + forward_logits, _ = model(x, x_lens) + + states = None + conv_caches = None + for chunk_idx in range(num_chunks): + start_idx = chunk_idx * chunk_length + end_idx = start_idx + chunk_length + chunk = x[:, start_idx : end_idx + R + 3] # noqa + lengths = torch.tensor([chunk.size(1)]) + ( + infer_chunk_logits, + output_lengths, + states, + conv_caches, + ) = model.infer(chunk, lengths, states, conv_caches) + forward_chunk_logits = forward_logits[ + :, start_idx // 4 : end_idx // 4 # noqa + ] + assert torch.allclose( + infer_chunk_logits, + forward_chunk_logits, + atol=1e-5, + rtol=0.0, + ) + + if __name__ == "__main__": test_emformer_attention_forward() test_emformer_attention_infer() @@ -375,3 +585,6 @@ if __name__ == "__main__": test_emformer_encoder_infer() test_emformer_forward() test_emformer_infer() + test_emformer_encoder_layer_forward_infer_consistency() + test_emformer_encoder_forward_infer_consistency() + test_emformer_forward_infer_consistency() From 021c79824eec174e847358c1a18b5d5e167f1bc4 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Tue, 19 Apr 2022 17:23:46 +0800 Subject: [PATCH 224/234] Add LG decoding (#277) * Add LG decoding * Add log weight pushing * Minor fixes --- egs/librispeech/ASR/local/compile_lg.py | 141 ++++++++++++++++++ egs/librispeech/ASR/prepare.sh | 11 ++ .../beam_search.py | 121 +++++++++++++-- .../ASR/pruned_transducer_stateless/decode.py | 90 ++++++++++- 4 files changed, 344 insertions(+), 19 deletions(-) create mode 100755 egs/librispeech/ASR/local/compile_lg.py diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py new file mode 100755 index 000000000..45c4b7f5f --- /dev/null +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input lang_dir and generates LG from + + - L, the lexicon, built from lang_dir/L_disambig.pt + + Caution: We use a lexicon that contains disambiguation symbols + + - G, the LM, built from data/lm/G_3_gram.fst.txt + +The generated LG is saved in $lang_dir/LG.pt +""" +import argparse +import logging +from pathlib import Path + +import k2 +import torch + +from icefall.lexicon import Lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + """, + ) + + return parser.parse_args() + + +def compile_LG(lang_dir: str) -> k2.Fsa: + """ + Args: + lang_dir: + The language directory, e.g., data/lang_phone or data/lang_bpe_5000. + + Return: + An FSA representing LG. + """ + lexicon = Lexicon(lang_dir) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + + if Path("data/lm/G_3_gram.pt").is_file(): + logging.info("Loading pre-compiled G_3_gram") + d = torch.load("data/lm/G_3_gram.pt") + G = k2.Fsa.from_dict(d) + else: + logging.info("Loading G_3_gram.fst.txt") + with open("data/lm/G_3_gram.fst.txt") as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + torch.save(G.as_dict(), "data/lm/G_3_gram.pt") + + first_token_disambig_id = lexicon.token_table["#0"] + first_word_disambig_id = lexicon.word_table["#0"] + + L = k2.arc_sort(L) + G = k2.arc_sort(G) + + logging.info("Intersecting L and G") + LG = k2.compose(L, G) + logging.info(f"LG shape: {LG.shape}") + + logging.info("Connecting LG") + LG = k2.connect(LG) + logging.info(f"LG shape after k2.connect: {LG.shape}") + + logging.info(type(LG.aux_labels)) + logging.info("Determinizing LG") + + LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing) + logging.info(type(LG.aux_labels)) + + logging.info("Connecting LG after k2.determinize") + LG = k2.connect(LG) + + logging.info("Removing disambiguation symbols on LG") + + LG.labels[LG.labels >= first_token_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set LG.properties to None + LG.__dict__["_properties"] = None + + assert isinstance(LG.aux_labels, k2.RaggedTensor) + LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 + + LG = k2.remove_epsilon(LG) + logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") + + LG = k2.connect(LG) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) + + logging.info("Arc sorting LG") + LG = k2.arc_sort(LG) + + return LG + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + if (lang_dir / "LG.pt").is_file(): + logging.info(f"{lang_dir}/LG.pt already exists - skipping") + return + + logging.info(f"Processing {lang_dir}") + + LG = compile_LG(lang_dir) + logging.info(f"Saving LG.pt to {lang_dir}") + torch.save(LG.as_dict(), f"{lang_dir}/LG.pt") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 1bbf7bbcf..6b61c6b57 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -242,3 +242,14 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then ./local/compile_hlg.py --lang-dir $lang_dir done fi + +# Compile LG for RNN-T fast_beam_search decoding +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Compile LG" + ./local/compile_lg.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir + done +fi diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index ef1f399c6..100aeaa6e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -22,7 +22,7 @@ import k2 import torch from model import Transducer -from icefall.decode import one_best_decoding +from icefall.decode import Nbest, one_best_decoding from icefall.utils import get_texts @@ -34,6 +34,7 @@ def fast_beam_search( beam: float, max_states: int, max_contexts: int, + use_max: bool = False, ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. @@ -53,6 +54,9 @@ def fast_beam_search( Max states per stream per frame. max_contexts: Max contexts pre stream per frame. + use_max: + True to use max operation to select the hypothesis with the largest + log_prob when there are duplicate hypotheses; False to use log-add. Returns: Return the decoded result. """ @@ -104,9 +108,67 @@ def fast_beam_search( decoding_streams.terminate_and_flush_to_streams() lattice = decoding_streams.format_output(encoder_out_lens.tolist()) - best_path = one_best_decoding(lattice) - hyps = get_texts(best_path) - return hyps + if use_max: + best_path = one_best_decoding(lattice) + hyps = get_texts(best_path) + return hyps + else: + num_paths = 200 + use_double_scores = True + nbest_scale = 0.8 + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # The following code is modified from nbest.intersect() + word_fsa = k2.invert(nbest.fsa) + if hasattr(lattice, "aux_labels"): + # delete token IDs as it is not needed + del word_fsa.aux_labels + word_fsa.scores.zero_() + + word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) + path_to_utt_map = nbest.shape.row_ids(1) + + if hasattr(lattice, "aux_labels"): + # lattice has token IDs as labels and word IDs as aux_labels. + # inv_lattice has word IDs as labels and token IDs as aux_labels + inv_lattice = k2.invert(lattice) + inv_lattice = k2.arc_sort(inv_lattice) + else: + inv_lattice = k2.arc_sort(lattice) + + if inv_lattice.shape[0] == 1: + path_lattice = k2.intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=torch.zeros_like(path_to_utt_map), + sorted_match_a=True, + ) + else: + path_lattice = k2.intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=path_to_utt_map, + sorted_match_a=True, + ) + + # path_lattice has word IDs as labels and token IDs as aux_labels + path_lattice = k2.top_sort(k2.connect(path_lattice)) + + tot_scores = path_lattice.get_tot_scores( + use_double_scores=use_double_scores, log_semiring=True + ) + + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + best_hyp_indexes = ragged_tot_scores.argmax() + + best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) + hyps = get_texts(best_path) + return hyps def greedy_search( @@ -280,7 +342,7 @@ class HypothesisList(object): def data(self) -> Dict[str, Hypothesis]: return self._data - def add(self, hyp: Hypothesis) -> None: + def add(self, hyp: Hypothesis, use_max: bool = False) -> None: """Add a Hypothesis to `self`. If `hyp` already exists in `self`, its probability is updated using @@ -289,13 +351,20 @@ class HypothesisList(object): Args: hyp: The hypothesis to be added. + use_max: + True to select the hypothesis with the larger log_prob in case there + already exists a hypothesis whose `ys` equals to `hyp.ys`. + False to use log_add. """ key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + if use_max: + old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob) + else: + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -403,6 +472,7 @@ def modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, + use_max: bool = False, ) -> List[List[int]]: """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. @@ -413,6 +483,9 @@ def modified_beam_search( Output from the encoder. Its shape is (N, T, C). beam: Number of active paths during the beam search. + use_max: + True to use max operation to select the hypothesis with the largest + log_prob when there are duplicate hypotheses; False to use log-add. Returns: Return a list-of-list of token IDs. ans[i] is the decoding results for the i-th utterance. @@ -432,7 +505,8 @@ def modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) + ), + use_max=use_max, ) for t in range(T): @@ -517,6 +591,7 @@ def _deprecated_modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, + use_max: bool = False, ) -> List[int]: """It limits the maximum number of symbols per frame to 1. @@ -532,6 +607,9 @@ def _deprecated_modified_beam_search( A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. beam: Beam size. + use_max: + True to use max operation to select the hypothesis with the largest + log_prob when there are duplicate hypotheses; False to use log-add. Returns: Return the decoded result. """ @@ -553,12 +631,13 @@ def _deprecated_modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) + ), + use_max=use_max, ) for t in range(T): # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + current_encoder_out = encoder_out[:, t:t + 1, :].unsqueeze(2) # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) # fmt: on A = list(B) @@ -611,7 +690,7 @@ def _deprecated_modified_beam_search( new_ys.append(new_token) new_log_prob = topk_log_probs[i] new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B.add(new_hyp) + B.add(new_hyp, use_max=use_max) best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks @@ -623,6 +702,7 @@ def beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, + use_max: bool = False, ) -> List[int]: """ It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf @@ -636,6 +716,9 @@ def beam_search( A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. beam: Beam size. + use_max: + True to use max operation to select the hypothesis with the largest + log_prob when there are duplicate hypotheses; False to use log-add. Returns: Return the decoded result. """ @@ -661,7 +744,9 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) + B.add( + Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max + ) max_sym_per_utt = 20000 @@ -720,7 +805,10 @@ def beam_search( new_y_star_log_prob = y_star.log_prob + skip_log_prob # ys[:] returns a copy of ys - B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) + B.add( + Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob), + use_max=use_max, + ) # Second, process other non-blank labels values, indices = log_prob.topk(beam + 1) @@ -729,7 +817,10 @@ def beam_search( continue new_ys = y_star.ys + [i] new_log_prob = y_star.log_prob + v - A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) + A.add( + Hypothesis(ys=new_ys, log_prob=new_log_prob), + use_max=use_max, + ) # Check whether B contains more than "beam" elements more probable # than the most probable in A diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 0e3b0f197..5082ab71a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -53,6 +53,19 @@ Usage: --beam 4 \ --max-contexts 4 \ --max-states 8 + +(5) fast beam search using LG +./pruned_transducer_stateless/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --use-LG True \ + --use-max False \ + --max-duration 1500 \ + --decoding-method fast_beam_search \ + --beam 8 \ + --max-contexts 8 \ + --max-states 64 """ @@ -81,10 +94,12 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) +from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, setup_logger, store_transcripts, + str2bool, write_error_stats, ) @@ -136,6 +151,13 @@ def get_parser(): help="Path to the BPE model", ) + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + parser.add_argument( "--decoding-method", type=str, @@ -167,6 +189,36 @@ def get_parser(): Used only when --decoding-method is fast_beam_search""", ) + parser.add_argument( + "--use-LG", + type=str2bool, + default=False, + help="""Whether to use an LG graph for FSA-based beam search. + Used only when --decoding_method is fast_beam_search. If setting true, + it assumes there is an LG.pt file in lang_dir.""", + ) + + parser.add_argument( + "--use-max", + type=str2bool, + default=False, + help="""If True, use max-op to select the hypothesis that have the + max log_prob in case of duplicate hypotheses. + If False, use log_add. + Used only for beam_search, modified_beam_search, and fast_beam_search + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search. + It specifies the scale for n-gram LM scores. + """, + ) + parser.add_argument( "--max-contexts", type=int, @@ -206,6 +258,7 @@ def decode_one_batch( model: nn.Module, sp: spm.SentencePieceProcessor, batch: dict, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the @@ -229,6 +282,8 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + word_table: + The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search. @@ -260,9 +315,14 @@ def decode_one_batch( beam=params.beam, max_contexts=params.max_contexts, max_states=params.max_states, + use_max=params.use_max, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + if params.use_LG: + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + else: + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 @@ -278,6 +338,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out, beam=params.beam_size, + use_max=params.use_max, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -299,6 +360,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, beam=params.beam_size, + use_max=params.use_max, ) else: raise ValueError( @@ -325,6 +387,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -338,6 +401,8 @@ def decode_dataset( The neural model. sp: The BPE model. + word_table: + The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search. @@ -368,8 +433,9 @@ def decode_dataset( params=params, model=model, sp=sp, - decoding_graph=decoding_graph, batch=batch, + word_table=word_table, + decoding_graph=decoding_graph, ) for name, hyps in hyps_dict.items(): @@ -460,13 +526,16 @@ def main(): params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" if "fast_beam_search" in params.decoding_method: + params.suffix += f"-use-LG-{params.use_LG}" params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"-use-max-{params.use_max}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" ) + params.suffix += f"-use-max-{params.use_max}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -527,9 +596,21 @@ def main(): model.device = device if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + if params.use_LG: + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + decoding_graph = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/LG.pt", map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None + word_table = None num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -551,6 +632,7 @@ def main(): params=params, model=model, sp=sp, + word_table=word_table, decoding_graph=decoding_graph, ) From 328ad280a40d84f44e0180e8470187635455f9d9 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Tue, 19 Apr 2022 17:58:51 +0800 Subject: [PATCH 225/234] Support state stacking and unstacking operations for emformer_pruned_transducer_stateless/emformer.py --- .../emformer.py | 75 ++++++++++++++++++- .../test_emformer.py | 39 ++++++++++ 2 files changed, 111 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py index 9eb5b966f..b6f93b4c7 100644 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py @@ -77,6 +77,75 @@ def _gen_attention_mask_block( return torch.cat(mask_block, dim=1) +def unstack_states( + states: List[List[torch.Tensor]], +) -> List[List[List[torch.Tensor]]]: + """Unstack the emformer state corresponding to a batch of utterances + into a list of states, were the i-th entry is the state from the i-th + utterance in the batch. + + Args: + states: + A list-of-list of tensors. ``len(states)`` equals to number of + layers in the emformer. ``states[i]]`` contains the states for + the i-th layer. ``states[i][k]`` is either a 3-D tensor of shape + ``(T, N, C)`` or a 2-D tensor of shape ``(C, N)`` + """ + batch_size = states[0][0].size(1) + num_layers = len(states) + + ans = [None] * batch_size + for i in range(batch_size): + ans[i] = [[] for _ in range(num_layers)] + + for li, layer in enumerate(states): + for s in layer: + s_list = s.unbind(dim=1) + # We will use stack(dim=1) later in stack_states() + for bi, b in enumerate(ans): + b[li].append(s_list[bi]) + return ans + + +def stack_states( + state_list: List[List[List[torch.Tensor]]], +) -> List[List[torch.Tensor]]: + """Stack list of emformer states that correspond to separate utterances + into a single emformer state so that it can be used as an input for + emformer when those utterances are formed into a batch. + + Note: + It is the inverse of :func:`unstack_states`. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the emformer model for a single utterance. + Returns: + Return a new state corresponding to a batch of utterances. + See the input argument of :func:`unstack_states` for the meaning + of the returned tensor. + """ + batch_size = len(state_list) + ans = [] + for layer in state_list[0]: + # layer is a list of tensors + if batch_size > 1: + ans.append([[s] for s in layer]) + # Note: We will stack ans[layer][s][] later to get ans[layer][s] + else: + ans.append([s.unsqueeze(1) for s in layer]) + + for b, states in enumerate(state_list[1:], 1): + for li, layer in enumerate(states): + for si, s in enumerate(layer): + ans[li][si].append(s) + if b == batch_size - 1: + ans[li][si] = torch.stack(ans[li][si], dim=1) + # We will use unbind(dim=1) later in unstack_states() + return ans + + class EmformerAttention(nn.Module): r"""Emformer layer attention module. @@ -424,9 +493,9 @@ class EmformerAttention(nn.Module): # key, value: [memory, right context, left context, uttrance] KV = ( memory.size(0) - + right_context.size(0) - + left_context_key.size(0) - + utterance.size(0) + + right_context.size(0) # noqa + + left_context_key.size(0) # noqa + + utterance.size(0) # noqa ) attention_mask = torch.zeros(Q, KV).to( dtype=torch.bool, device=utterance.device diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py index abc023bb7..ecfe24c61 100644 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/test_emformer.py @@ -682,6 +682,44 @@ def test_emformer_infer_batch_single_consistency(): assert torch.allclose(batch_logits, single_logits, atol=1e-5, rtol=0.0) +def test_emformer_infer_states_stack(): + from emformer import Emformer, unstack_states, stack_states + + num_features = 80 + output_dim = 1000 + chunk_length = 8 + U = chunk_length + L, R = 128, 4 + B, D = 2, 256 + num_encoder_layers = 2 + for use_memory in [True, False]: + if use_memory: + M = 3 + else: + M = 0 + model = Emformer( + num_features=num_features, + output_dim=output_dim, + chunk_length=chunk_length, + subsampling_factor=4, + d_model=D, + num_encoder_layers=num_encoder_layers, + left_context_length=L, + right_context_length=R, + max_memory_size=M, + vgg_frontend=False, + ) + + x = torch.randn(B, U + R + 3, num_features) + x_lens = torch.full((B, ), U + R + 3) + logits, output_lengths, states = model.infer(x, x_lens,) + states2 = stack_states(unstack_states(states)) + + for ss, ss2 in zip(states, states2): + for s, s2 in zip(ss, ss2): + assert torch.allclose(s, s2), f"{s.sum()}, {s2.sum()}" + + if __name__ == "__main__": test_emformer_attention_forward() test_emformer_attention_infer() @@ -695,3 +733,4 @@ if __name__ == "__main__": test_emformer_layer_forward_infer_consistency() test_emformer_encoder_forward_infer_consistency() test_emformer_infer_batch_single_consistency() + test_emformer_infer_states_stack() From fce7f3cd9a486405ee008bcbe4999264f27774a3 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 19 Apr 2022 18:47:13 +0800 Subject: [PATCH 226/234] Support computing RNN-T loss with torchaudio (#316) --- ...-pruned-transducer-stateless-2022-03-12.sh | 47 ++ ...speech-transducer-stateless2-2022-04-19.sh | 47 ++ .../scripts/run-pre-trained-conformer-ctc.sh | 46 ++ ...d-transducer-stateless-librispeech-100h.sh | 47 ++ ...d-transducer-stateless-librispeech-960h.sh | 47 ++ ...transducer-stateless-modified-2-aishell.sh | 47 ++ ...d-transducer-stateless-modified-aishell.sh | 47 ++ .../run-pre-trained-transducer-stateless.sh | 60 ++ .github/scripts/run-pre-trained-transducer.sh | 32 + .../workflows/run-librispeech-2022-03-12.yml | 104 +-- ...peech-transducer-stateless2-2022-04-19.yml | 82 ++ .../run-pretrained-conformer-ctc.yml | 46 +- ...-transducer-stateless-librispeech-100h.yml | 95 +-- ...r-stateless-librispeech-multi-datasets.yml | 97 +-- ...ransducer-stateless-modified-2-aishell.yml | 96 +-- ...-transducer-stateless-modified-aishell.yml | 96 +-- .../run-pretrained-transducer-stateless.yml | 94 +-- .../workflows/run-pretrained-transducer.yml | 46 +- egs/librispeech/ASR/README.md | 15 +- egs/librispeech/ASR/RESULTS.md | 68 ++ .../ASR/pruned_transducer_stateless/train.py | 22 +- .../ASR/transducer_stateless/beam_search.py | 13 +- .../ASR/transducer_stateless/train.py | 22 +- .../ASR/transducer_stateless2/__init__.py | 0 .../transducer_stateless2/asr_datamodule.py | 1 + .../ASR/transducer_stateless2/beam_search.py | 1 + .../ASR/transducer_stateless2/conformer.py | 1 + .../ASR/transducer_stateless2/decode.py | 443 ++++++++++ .../ASR/transducer_stateless2/decoder.py | 1 + .../encoder_interface.py | 1 + .../ASR/transducer_stateless2/export.py | 181 ++++ .../ASR/transducer_stateless2/joiner.py | 67 ++ .../ASR/transducer_stateless2/model.py | 130 +++ .../ASR/transducer_stateless2/pretrained.py | 293 +++++++ .../ASR/transducer_stateless2/subsampling.py | 1 + .../ASR/transducer_stateless2/train.py | 779 ++++++++++++++++++ .../ASR/transducer_stateless2/transformer.py | 1 + 37 files changed, 2536 insertions(+), 680 deletions(-) create mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh create mode 100755 .github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh create mode 100755 .github/scripts/run-pre-trained-conformer-ctc.sh create mode 100755 .github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh create mode 100755 .github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh create mode 100755 .github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh create mode 100755 .github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh create mode 100755 .github/scripts/run-pre-trained-transducer-stateless.sh create mode 100755 .github/scripts/run-pre-trained-transducer.sh create mode 100644 .github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml create mode 100644 egs/librispeech/ASR/transducer_stateless2/__init__.py create mode 120000 egs/librispeech/ASR/transducer_stateless2/asr_datamodule.py create mode 120000 egs/librispeech/ASR/transducer_stateless2/beam_search.py create mode 120000 egs/librispeech/ASR/transducer_stateless2/conformer.py create mode 100755 egs/librispeech/ASR/transducer_stateless2/decode.py create mode 120000 egs/librispeech/ASR/transducer_stateless2/decoder.py create mode 120000 egs/librispeech/ASR/transducer_stateless2/encoder_interface.py create mode 100755 egs/librispeech/ASR/transducer_stateless2/export.py create mode 100644 egs/librispeech/ASR/transducer_stateless2/joiner.py create mode 100644 egs/librispeech/ASR/transducer_stateless2/model.py create mode 100755 egs/librispeech/ASR/transducer_stateless2/pretrained.py create mode 120000 egs/librispeech/ASR/transducer_stateless2/subsampling.py create mode 100755 egs/librispeech/ASR/transducer_stateless2/train.py create mode 120000 egs/librispeech/ASR/transducer_stateless2/transformer.py diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh new file mode 100755 index 000000000..2387a16e2 --- /dev/null +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search; do + log "$method" + + ./pruned_transducer_stateless/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done diff --git a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh new file mode 100755 index 000000000..102547c8b --- /dev/null +++ b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless2-torchaudio-2022-04-19 + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./transducer_stateless2/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search; do + log "$method" + + ./transducer_stateless2/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh new file mode 100755 index 000000000..96a072c46 --- /dev/null +++ b/.github/scripts/run-pre-trained-conformer-ctc.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500 +git lfs install +git clone $repo + +log "Downloading pre-trained model from $repo_url" +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.flac +ls -lh $repo/test_wavs/*.flac + +log "CTC decoding" + +./conformer_ctc/pretrained.py \ + --method ctc-decoding \ + --num-classes 500 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.flac \ + $repo/test_wavs/1221-135766-0001.flac \ + $repo/test_wavs/1221-135766-0002.flac + +log "HLG decoding" + +./conformer_ctc/pretrained.py \ + --method 1best \ + --num-classes 500 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + $repo/test_wavs/1089-134686-0001.flac \ + $repo/test_wavs/1221-135766-0001.flac \ + $repo/test_wavs/1221-135766-0002.flac diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh new file mode 100755 index 000000000..f484bd49a --- /dev/null +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21 + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./transducer_stateless_multi_datasets/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search; do + log "$method" + + ./transducer_stateless_multi_datasets/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh new file mode 100755 index 000000000..5501dcecd --- /dev/null +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01 + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./transducer_stateless_multi_datasets/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search; do + log "$method" + + ./transducer_stateless_multi_datasets/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh new file mode 100755 index 000000000..168aee766 --- /dev/null +++ b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/aishell/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2-2022-03-01 + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./transducer_stateless_modified-2/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_char \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav +done + +for method in modified_beam_search beam_search; do + log "$method" + + ./transducer_stateless_modified-2/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_char \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav +done diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh new file mode 100755 index 000000000..9211b22eb --- /dev/null +++ b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/aishell/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2022-03-01 + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./transducer_stateless_modified/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_char \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav +done + +for method in modified_beam_search beam_search; do + log "$method" + + ./transducer_stateless_modified/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_char \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $repo/test_wavs/BAC009S0764W0123.wav +done diff --git a/.github/scripts/run-pre-trained-transducer-stateless.sh b/.github/scripts/run-pre-trained-transducer-stateless.sh new file mode 100755 index 000000000..cb57602e3 --- /dev/null +++ b/.github/scripts/run-pre-trained-transducer-stateless.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07 + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./transducer_stateless/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search; do + log "$method" + + ./transducer_stateless/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search; do + log "$method" + + ./transducer_stateless_multi_datasets/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done diff --git a/.github/scripts/run-pre-trained-transducer.sh b/.github/scripts/run-pre-trained-transducer.sh new file mode 100755 index 000000000..5f8a5b3a5 --- /dev/null +++ b/.github/scripts/run-pre-trained-transducer.sh @@ -0,0 +1,32 @@ +#!/usr/bin/env bash + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23 + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +log "Beam search decoding" + +./transducer/pretrained.py \ + --method beam_search \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/workflows/run-librispeech-2022-03-12.yml b/.github/workflows/run-librispeech-2022-03-12.yml index 221104f8f..135285f15 100644 --- a/.github/workflows/run-librispeech-2022-03-12.yml +++ b/.github/workflows/run-librispeech-2022-03-12.yml @@ -40,11 +40,6 @@ jobs: with: fetch-depth: 0 - - name: Install graphviz - shell: bash - run: | - sudo apt-get -qq install graphviz - - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: @@ -77,104 +72,11 @@ jobs: cmake -DCMAKE_BUILD_TYPE=Release .. make -j2 _kaldifeat - - name: Download pre-trained model - shell: bash - run: | - sudo apt-get -qq install git-lfs - mkdir -p ~/tmp - cd ~/tmp - git lfs install - git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 - - - name: Display test files - shell: bash - run: | - sudo apt-get -qq install tree sox - tree ~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 - soxi ~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12/test_wavs/*.wav - ls -lh ~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12/test_wavs/*.wav - - - name: Run greedy search decoding (max-sym-per-frame 1) + - name: Inference with pre-trained model shell: bash run: | + sudo apt-get -qq install git-lfs tree sox export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 - cd egs/librispeech/ASR - ./pruned_transducer_stateless/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame 1 \ - --checkpoint $dir/exp/pretrained.pt \ - --bpe-model $dir/data/lang_bpe_500/bpe.model \ - $dir/test_wavs/1089-134686-0001.wav \ - $dir/test_wavs/1221-135766-0001.wav \ - $dir/test_wavs/1221-135766-0002.wav - - - name: Run greedy search decoding (max-sym-per-frame 2) - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 - cd egs/librispeech/ASR - ./pruned_transducer_stateless/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame 2 \ - --checkpoint $dir/exp/pretrained.pt \ - --bpe-model $dir/data/lang_bpe_500/bpe.model \ - $dir/test_wavs/1089-134686-0001.wav \ - $dir/test_wavs/1221-135766-0001.wav \ - $dir/test_wavs/1221-135766-0002.wav - - - name: Run greedy search decoding (max-sym-per-frame 3) - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 - cd egs/librispeech/ASR - ./pruned_transducer_stateless/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame 3 \ - --checkpoint $dir/exp/pretrained.pt \ - --bpe-model $dir/data/lang_bpe_500/bpe.model \ - $dir/test_wavs/1089-134686-0001.wav \ - $dir/test_wavs/1221-135766-0001.wav \ - $dir/test_wavs/1221-135766-0002.wav - - - name: Run beam search decoding - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 - cd egs/librispeech/ASR - ./pruned_transducer_stateless/pretrained.py \ - --method beam_search \ - --beam-size 4 \ - --checkpoint $dir/exp/pretrained.pt \ - --bpe-model $dir/data/lang_bpe_500/bpe.model \ - $dir/test_wavs/1089-134686-0001.wav \ - $dir/test_wavs/1221-135766-0001.wav \ - $dir/test_wavs/1221-135766-0002.wav - - - name: Run modified beam search decoding - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - dir=~/tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12 - cd egs/librispeech/ASR - ./pruned_transducer_stateless/pretrained.py \ - --method modified_beam_search \ - --beam-size 4 \ - --checkpoint $dir/exp/pretrained.pt \ - --bpe-model $dir/data/lang_bpe_500/bpe.model \ - $dir/test_wavs/1089-134686-0001.wav \ - $dir/test_wavs/1221-135766-0001.wav \ - $dir/test_wavs/1221-135766-0002.wav + .github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh diff --git a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml new file mode 100644 index 000000000..5871f926d --- /dev/null +++ b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml @@ -0,0 +1,82 @@ +# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-librispeech-2022-04-19 +# stateless transducer + torchaudio rnn-t loss + +on: + push: + branches: + - master + pull_request: + types: [labeled] + +jobs: + run_librispeech_2022_04_19: + if: github.event.label.name == 'ready' || github.event_name == 'push' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-18.04] + python-version: [3.7, 3.8, 3.9] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }} + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + mkdir -p ~/tmp + cd ~/tmp + git clone https://github.com/csukuangfj/kaldifeat + cd kaldifeat + mkdir build + cd build + cmake -DCMAKE_BUILD_TYPE=Release .. + make -j2 _kaldifeat + + - name: Inference with pre-trained model + shell: bash + run: | + sudo apt-get -qq install git-lfs tree sox + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + .github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-conformer-ctc.yml index cd24c9c44..6575ceb65 100644 --- a/.github/workflows/run-pretrained-conformer-ctc.yml +++ b/.github/workflows/run-pretrained-conformer-ctc.yml @@ -39,11 +39,6 @@ jobs: with: fetch-depth: 0 - - name: Install graphviz - shell: bash - run: | - sudo apt-get -qq install graphviz - - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: @@ -76,48 +71,11 @@ jobs: cmake -DCMAKE_BUILD_TYPE=Release .. make -j2 _kaldifeat - - name: Download pre-trained model + - name: Inference with pre-trained model shell: bash run: | sudo apt-get -qq install git-lfs tree sox - cd egs/librispeech/ASR - mkdir tmp - cd tmp - git lfs install - git clone https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500 - cd .. - tree tmp - soxi tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/*.flac - ls -lh tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/*.flac - - - name: Run CTC decoding - shell: bash - run: | export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/librispeech/ASR - ./conformer_ctc/pretrained.py \ - --num-classes 500 \ - --checkpoint ./tmp/icefall-asr-conformer-ctc-bpe-500/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/bpe.model \ - --method ctc-decoding \ - ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1089-134686-0001.flac \ - ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0001.flac \ - ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0002.flac - - - name: Run HLG decoding - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/librispeech/ASR - ./conformer_ctc/pretrained.py \ - --num-classes 500 \ - --checkpoint ./tmp/icefall-asr-conformer-ctc-bpe-500/exp/pretrained.pt \ - --words-file ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/words.txt \ - --HLG ./tmp/icefall-asr-conformer-ctc-bpe-500/data/lang_bpe_500/HLG.pt \ - ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1089-134686-0001.flac \ - ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0001.flac \ - ./tmp/icefall-asr-conformer-ctc-bpe-500/test_wavs/1221-135766-0002.flac + .github/scripts/run-pre-trained-conformer-ctc.sh diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml index b827ec82e..80ab356e6 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml @@ -39,11 +39,6 @@ jobs: with: fetch-depth: 0 - - name: Install graphviz - shell: bash - run: | - sudo apt-get -qq install graphviz - - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: @@ -76,97 +71,11 @@ jobs: cmake -DCMAKE_BUILD_TYPE=Release .. make -j2 _kaldifeat - - name: Download pre-trained model + - name: Inference with pre-trained model shell: bash run: | sudo apt-get -qq install git-lfs tree sox - cd egs/librispeech/ASR - mkdir tmp - cd tmp - git lfs install - git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21 - - cd .. - tree tmp - soxi tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/*.wav - ls -lh tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/*.wav - - - name: Run greedy search decoding (max-sym-per-frame 1) - shell: bash - run: | export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/librispeech/ASR - ./transducer_stateless_multi_datasets/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame 1 \ - --checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav - - - name: Run greedy search decoding (max-sym-per-frame 2) - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/librispeech/ASR - ./transducer_stateless_multi_datasets/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame 2 \ - --checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav - - - name: Run greedy search decoding (max-sym-per-frame 3) - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/librispeech/ASR - ./transducer_stateless_multi_datasets/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame 3 \ - --checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav - - - name: Run beam search decoding - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/librispeech/ASR - ./transducer_stateless_multi_datasets/pretrained.py \ - --method beam_search \ - --beam-size 4 \ - --checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav - - - name: Run modified beam search decoding - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/librispeech/ASR - ./transducer_stateless_multi_datasets/pretrained.py \ - --method modified_beam_search \ - --beam-size 4 \ - --checkpoint ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-100h-transducer-stateless-multi-datasets-bpe-500-2022-02-21/test_wavs/1221-135766-0002.wav + .github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml index ffd9bdaec..d2231750c 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml @@ -39,11 +39,6 @@ jobs: with: fetch-depth: 0 - - name: Install graphviz - shell: bash - run: | - sudo apt-get -qq install graphviz - - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: @@ -76,99 +71,11 @@ jobs: cmake -DCMAKE_BUILD_TYPE=Release .. make -j2 _kaldifeat - - name: Download pre-trained model + - name: Inference with pre-trained model shell: bash run: | sudo apt-get -qq install git-lfs tree sox - cd egs/librispeech/ASR - mkdir tmp - cd tmp - git lfs install - git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01 - - - cd .. - tree tmp - soxi tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/*.wav - ls -lh tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/*.wav - - - name: Run greedy search decoding (max-sym-per-frame 1) - shell: bash - run: | export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/librispeech/ASR - ./transducer_stateless_multi_datasets/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame 1 \ - --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav - - - name: Run greedy search decoding (max-sym-per-frame 2) - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/librispeech/ASR - ./transducer_stateless_multi_datasets/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame 2 \ - --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav - - - name: Run greedy search decoding (max-sym-per-frame 3) - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/librispeech/ASR - ./transducer_stateless_multi_datasets/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame 3 \ - --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav - - - name: Run beam search decoding - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/librispeech/ASR - ./transducer_stateless_multi_datasets/pretrained.py \ - --method beam_search \ - --beam-size 4 \ - --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav - - - - name: Run modified beam search decoding - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/librispeech/ASR - ./transducer_stateless_multi_datasets/pretrained.py \ - --method modified_beam_search \ - --beam-size 4 \ - --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01/test_wavs/1221-135766-0002.wav + .github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml index 12652a22d..a84e804c6 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml @@ -39,11 +39,6 @@ jobs: with: fetch-depth: 0 - - name: Install graphviz - shell: bash - run: | - sudo apt-get -qq install graphviz - - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: @@ -76,98 +71,11 @@ jobs: cmake -DCMAKE_BUILD_TYPE=Release .. make -j2 _kaldifeat - - name: Download pre-trained model + - name: Inference with pre-trained model shell: bash run: | sudo apt-get -qq install git-lfs tree sox - cd egs/aishell/ASR - mkdir tmp - cd tmp - git lfs install - git clone https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2-2022-03-01 - - cd .. - tree tmp - soxi tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/*.wav - ls -lh tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/*.wav - - - name: Run greedy search decoding (max-sym-per-frame 1) - shell: bash - run: | export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/aishell/ASR - ./transducer_stateless_modified-2/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame 1 \ - --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \ - --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \ - ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav - - - name: Run greedy search decoding (max-sym-per-frame 2) - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/aishell/ASR - ./transducer_stateless_modified-2/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame 2 \ - --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \ - --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \ - ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav - - - name: Run greedy search decoding (max-sym-per-frame 3) - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/aishell/ASR - ./transducer_stateless_modified-2/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame 3 \ - --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \ - --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \ - ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav - - - name: Run beam search decoding - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/aishell/ASR - ./transducer_stateless_modified-2/pretrained.py \ - --method beam_search \ - --beam-size 4 \ - --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \ - --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \ - ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav - - - - name: Run modified beam search decoding - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/aishell/ASR - ./transducer_stateless_modified-2/pretrained.py \ - --method modified_beam_search \ - --beam-size 4 \ - --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/exp/pretrained.pt \ - --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/data/lang_char \ - ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0121.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0122.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/BAC009S0764W0123.wav + .github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml index aa69d1500..7fa48d15a 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml @@ -39,11 +39,6 @@ jobs: with: fetch-depth: 0 - - name: Install graphviz - shell: bash - run: | - sudo apt-get -qq install graphviz - - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: @@ -76,98 +71,11 @@ jobs: cmake -DCMAKE_BUILD_TYPE=Release .. make -j2 _kaldifeat - - name: Download pre-trained model + - name: Inference with pre-trained model shell: bash run: | sudo apt-get -qq install git-lfs tree sox - cd egs/aishell/ASR - mkdir tmp - cd tmp - git lfs install - git clone https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2022-03-01 - - cd .. - tree tmp - soxi tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/*.wav - ls -lh tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/*.wav - - - name: Run greedy search decoding (max-sym-per-frame 1) - shell: bash - run: | export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/aishell/ASR - ./transducer_stateless_modified/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame 1 \ - --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \ - --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \ - ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav - - - name: Run greedy search decoding (max-sym-per-frame 2) - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/aishell/ASR - ./transducer_stateless_modified/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame 2 \ - --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \ - --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \ - ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav - - - name: Run greedy search decoding (max-sym-per-frame 3) - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/aishell/ASR - ./transducer_stateless_modified/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame 3 \ - --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \ - --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \ - ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav - - - name: Run beam search decoding - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/aishell/ASR - ./transducer_stateless_modified/pretrained.py \ - --method beam_search \ - --beam-size 4 \ - --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \ - --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \ - ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav - - - - name: Run modified beam search decoding - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/aishell/ASR - ./transducer_stateless_modified/pretrained.py \ - --method modified_beam_search \ - --beam-size 4 \ - --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \ - --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \ - ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \ - ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav + .github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml index 535e46261..678e79339 100644 --- a/.github/workflows/run-pretrained-transducer-stateless.yml +++ b/.github/workflows/run-pretrained-transducer-stateless.yml @@ -39,11 +39,6 @@ jobs: with: fetch-depth: 0 - - name: Install graphviz - shell: bash - run: | - sudo apt-get -qq install graphviz - - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: @@ -76,96 +71,11 @@ jobs: cmake -DCMAKE_BUILD_TYPE=Release .. make -j2 _kaldifeat - - name: Download pre-trained model + - name: Inference with pre-trained model shell: bash run: | sudo apt-get -qq install git-lfs tree sox - cd egs/librispeech/ASR - mkdir tmp - cd tmp - git lfs install - git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07 - cd .. - tree tmp - soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav - ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav - - - name: Run greedy search decoding (max-sym-per-frame 1) - shell: bash - run: | export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/librispeech/ASR - ./transducer_stateless/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame 1 \ - --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav - - - name: Run greedy search decoding (max-sym-per-frame 2) - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/librispeech/ASR - ./transducer_stateless/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame 2 \ - --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav - - - name: Run greedy search decoding (max-sym-per-frame 3) - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/librispeech/ASR - ./transducer_stateless/pretrained.py \ - --method greedy_search \ - --max-sym-per-frame 3 \ - --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav - - - name: Run beam search decoding - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/librispeech/ASR - ./transducer_stateless/pretrained.py \ - --method beam_search \ - --beam-size 4 \ - --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav - - - name: Run modified beam search decoding - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/librispeech/ASR - ./transducer_stateless/pretrained.py \ - --method modified_beam_search \ - --beam-size 4 \ - --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav + .github/scripts/run-pre-trained-transducer-stateless.sh diff --git a/.github/workflows/run-pretrained-transducer.yml b/.github/workflows/run-pretrained-transducer.yml index 41e4cfe0d..781783bcf 100644 --- a/.github/workflows/run-pretrained-transducer.yml +++ b/.github/workflows/run-pretrained-transducer.yml @@ -39,11 +39,6 @@ jobs: with: fetch-depth: 0 - - name: Install graphviz - shell: bash - run: | - sudo apt-get -qq install graphviz - - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: @@ -76,48 +71,11 @@ jobs: cmake -DCMAKE_BUILD_TYPE=Release .. make -j2 _kaldifeat - - name: Download pre-trained model + - name: Inference with pre-trained model shell: bash run: | sudo apt-get -qq install git-lfs tree sox - cd egs/librispeech/ASR - mkdir tmp - cd tmp - git lfs install - git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23 - - cd .. - tree tmp - soxi tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav - ls -lh tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav - - - name: Run greedy search decoding - shell: bash - run: | export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/librispeech/ASR - ./transducer/pretrained.py \ - --method greedy_search \ - --checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav - - - name: Run beam search decoding - shell: bash - run: | - export PYTHONPATH=$PWD:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH - export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - cd egs/librispeech/ASR - ./transducer/pretrained.py \ - --method beam_search \ - --beam-size 4 \ - --checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \ - --bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \ - ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \ - ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav + .github/scripts/run-pre-trained-transducer.sh diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index b3e90a052..de9d6d50a 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -10,13 +10,14 @@ There are various folders containing the name `transducer` in this folder. The following table lists the differences among them. | | Encoder | Decoder | Comment | -|---------------------------------------|---------------------|--------------------|---------------------------------------------------| -| `transducer` | Conformer | LSTM | | -| `transducer_stateless` | Conformer | Embedding + Conv1d | | -| `transducer_lstm` | LSTM | LSTM | | -| `transducer_stateless_multi_datasets` | Conformer | Embedding + Conv1d | Using data from GigaSpeech as extra training data | -| `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss | -| `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | +|---------------------------------------|---------------------|--------------------|-------------------------------------------------------| +| `transducer` | Conformer | LSTM | | +| `transducer_stateless` | Conformer | Embedding + Conv1d | Using optimized_transducer from computing RNN-T loss | +| `transducer_stateless2` | Conformer | Embedding + Conv1d | Using torchaudio for computing RNN-T loss | +| `transducer_lstm` | LSTM | LSTM | | +| `transducer_stateless_multi_datasets` | Conformer | Embedding + Conv1d | Using data from GigaSpeech as extra training data | +| `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss | +| `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | The decoder in `transducer_stateless` is modified from the paper diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 3488535a6..ac8b3ec75 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -350,6 +350,74 @@ You can find a pretrained model by visiting +##### 2022-04-19 + +[transducer_stateless2](./transducer_stateless2) +This version uses torchaudio's RNN-T loss. + +| | test-clean | test-other | comment | +|-------------------------------------|------------|------------|--------------------------------------------------------------------------------| +| greedy search (max sym per frame 1) | 2.65 | 6.30 | --epoch 59 --avg 10 --max-duration 600 | +| greedy search (max sym per frame 2) | 2.62 | 6.23 | --epoch 59 --avg 10 --max-duration 100 | +| greedy search (max sym per frame 3) | 2.62 | 6.23 | --epoch 59 --avg 10 --max-duration 100 | +| modified beam search | 2.63 | 6.15 | --epoch 59 --avg 10 --max-duration 100 --decoding-method modified_beam_search | +| beam search | 2.59 | 6.15 | --epoch 59 --avg 10 --max-duration 100 --decoding-method beam_search | + +**Note**: This model is trained with standard RNN-T loss. Neither modified transducer nor pruned RNN-T is used. +You can see that there is a performance degradation in WER when we limit the max symbol per frame to 1. + +The number of active paths in `modified_beam_search` and `beam_search` is 4. + +The training and decoding commands are: + +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +./transducer_stateless2/train.py \ + --world-size 8 \ + --num-epochs 60 \ + --start-epoch 0 \ + --exp-dir transducer_stateless2/exp-2 \ + --full-libri 1 \ + --max-duration 300 \ + --lr-factor 5 + +epoch=59 +avg=10 +# greedy search +./transducer_stateless2/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir ./transducer_stateless2/exp-2 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --max-sym-per-frame 1 + +# modified beam search +./transducer_stateless2/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir ./transducer_stateless2/exp-2 \ + --max-duration 100 \ + --decoding-method modified_beam_search \ + +# beam search +./transducer_stateless2/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir ./transducer_stateless2/exp-2 \ + --max-duration 100 \ + --decoding-method beam_search \ +``` + +The tensorboard log is at . + + +You can find a pre-trained model, decoding logs, and decoding results at + + + + ##### 2022-02-07 Using commit `a8150021e01d34ecbd6198fe03a57eacf47a16f2`. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index f0ea12d62..c360d025a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -811,13 +811,23 @@ def run(rank, world_size, args): train_cuts = train_cuts.filter(remove_short_and_long_utt) - num_left = len(train_cuts) - num_removed = num_in_total - num_left - removed_percent = num_removed / num_in_total * 100 + try: + num_left = len(train_cuts) + num_removed = num_in_total - num_left + removed_percent = num_removed / num_in_total * 100 - logging.info(f"Before removing short and long utterances: {num_in_total}") - logging.info(f"After removing short and long utterances: {num_left}") - logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") + logging.info( + f"Before removing short and long utterances: {num_in_total}" + ) + logging.info(f"After removing short and long utterances: {num_left}") + logging.info( + f"Removed {num_removed} utterances ({removed_percent:.5f}%)" + ) + except TypeError as e: + # You can ignore this error as previous versions of Lhotse work fine + # for the above code. In recent versions of Lhotse, it uses + # lazy filter, producing cutsets that don't have the __len__ method + logging.info(str(e)) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index 7b4fac31d..388a8d67a 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass from typing import Dict, List, Optional @@ -505,8 +506,10 @@ def modified_beam_search( for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() for k in range(len(topk_hyp_indexes)): hyp_idx = topk_hyp_indexes[k] @@ -613,8 +616,10 @@ def _deprecated_modified_beam_search( topk_hyp_indexes = topk_indexes // logits.size(-1) topk_token_indexes = topk_indexes % logits.size(-1) - topk_hyp_indexes = topk_hyp_indexes.tolist() - topk_token_indexes = topk_token_indexes.tolist() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() for i in range(len(topk_hyp_indexes)): hyp = A[topk_hyp_indexes[i]] diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index d6827c17c..89f754b20 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -653,13 +653,23 @@ def run(rank, world_size, args): train_cuts = train_cuts.filter(remove_short_and_long_utt) - num_left = len(train_cuts) - num_removed = num_in_total - num_left - removed_percent = num_removed / num_in_total * 100 + try: + num_left = len(train_cuts) + num_removed = num_in_total - num_left + removed_percent = num_removed / num_in_total * 100 - logging.info(f"Before removing short and long utterances: {num_in_total}") - logging.info(f"After removing short and long utterances: {num_left}") - logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") + logging.info( + f"Before removing short and long utterances: {num_in_total}" + ) + logging.info(f"After removing short and long utterances: {num_left}") + logging.info( + f"Removed {num_removed} utterances ({removed_percent:.5f}%)" + ) + except TypeError as e: + # You can ignore this error as previous versions of Lhotse work fine + # for the above code. In recent versions of Lhotse, it uses + # lazy filter, producing cutsets that don't have the __len__ method + logging.info(str(e)) train_dl = librispeech.train_dataloaders(train_cuts) diff --git a/egs/librispeech/ASR/transducer_stateless2/__init__.py b/egs/librispeech/ASR/transducer_stateless2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/transducer_stateless2/asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless2/asr_datamodule.py new file mode 120000 index 000000000..fa1b8cca3 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/asr_datamodule.py @@ -0,0 +1 @@ +../tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_stateless2/beam_search.py b/egs/librispeech/ASR/transducer_stateless2/beam_search.py new file mode 120000 index 000000000..08cb32ef7 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/beam_search.py @@ -0,0 +1 @@ +../transducer_stateless/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_stateless2/conformer.py b/egs/librispeech/ASR/transducer_stateless2/conformer.py new file mode 120000 index 000000000..70a7ddf11 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/conformer.py @@ -0,0 +1 @@ +../transducer_stateless/conformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py new file mode 100755 index 000000000..08c61c2be --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/decode.py @@ -0,0 +1,443 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./transducer_stateless2/decode.py \ + --epoch 14 \ + --avg 7 \ + --exp-dir ./transducer_stateless2/exp \ + --max-duration 100 \ + --decoding-method greedy_search + +(2) beam search +./transducer_stateless2/decode.py \ + --epoch 14 \ + --avg 7 \ + --exp-dir ./transducer_stateless2/exp \ + --max-duration 100 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./transducer_stateless2/decode.py \ + --epoch 14 \ + --avg 7 \ + --exp-dir ./transducer_stateless2/exp \ + --max-duration 100 \ + --decoding-method modified_beam_search \ + --beam-size 4 +""" + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=29, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=13, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""Used only when --decoding-method is + beam_search or modified_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = model.device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + hyp_list: List[List[int]] = [] + + if ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_list = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + elif params.decoding_method == "modified_beam_search": + hyp_list = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + else: + batch_size = encoder_out.size(0) + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyp_list.append(hyp) + + hyps = [sp.decode(hyp).split() for hyp in hyp_list] + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + else: + return {f"beam_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 100 + else: + log_interval = 2 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if "beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.to(device) + model.eval() + model.device = device + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer_stateless2/decoder.py b/egs/librispeech/ASR/transducer_stateless2/decoder.py new file mode 120000 index 000000000..eada91097 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/decoder.py @@ -0,0 +1 @@ +../transducer_stateless/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_stateless2/encoder_interface.py b/egs/librispeech/ASR/transducer_stateless2/encoder_interface.py new file mode 120000 index 000000000..aa5d0217a --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/encoder_interface.py @@ -0,0 +1 @@ +../transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py new file mode 100755 index 000000000..7a68f69ff --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/export.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" +Usage: +./transducer_stateless2/export.py \ + --exp-dir ./transducer_stateless2/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file exp_dir/pretrained.pt + +To use the generated file with `transducer_stateless2/decode.py`, you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./transducer_stateless2/decode.py \ + --exp-dir ./transducer_stateless2/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 1 \ + --bpe-model data/lang_bpe_500/bpe.model +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from train import get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=20, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=10, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer_stateless2/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + return parser + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + assert args.jit is False, "Support torchscript will be added later" + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.eval() + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/transducer_stateless2/joiner.py b/egs/librispeech/ASR/transducer_stateless2/joiner.py new file mode 100644 index 000000000..765f0be8b --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/joiner.py @@ -0,0 +1,67 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + + +class Joiner(nn.Module): + def __init__(self, input_dim: int, output_dim: int): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.output_linear = nn.Linear(input_dim, output_dim) + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + *unused, + ) -> torch.Tensor: + """ + Args: + encoder_out: + Output from the encoder. Its shape is (N, T, self.input_dim). + decoder_out: + Output from the decoder. Its shape is (N, U, self.input_dim). + unused: + This is a placeholder so that we can reuse + transducer_stateless/beam_search.py in this folder as that + script assumes the joiner networks accepts 4 inputs. + Returns: + Return a tensor of shape (N, T, U, self.output_dim). + """ + assert encoder_out.ndim == decoder_out.ndim == 3 + assert encoder_out.size(0) == decoder_out.size(0) + assert encoder_out.size(2) == self.input_dim + assert decoder_out.size(2) == self.input_dim + + encoder_out = encoder_out.unsqueeze(2) # (N, T, 1, C) + decoder_out = decoder_out.unsqueeze(1) # (N, 1, U, C) + x = encoder_out + decoder_out # (N, T, U, C) + + activations = torch.tanh(x) + + logits = self.output_linear(activations) + + if not self.training: + # We reuse the beam_search.py from transducer_stateless, + # which expects that the joiner network outputs + # a 2-D tensor. + logits = logits.squeeze(2).squeeze(1) + + return logits diff --git a/egs/librispeech/ASR/transducer_stateless2/model.py b/egs/librispeech/ASR/transducer_stateless2/model.py new file mode 100644 index 000000000..d04716706 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/model.py @@ -0,0 +1,130 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note we use `rnnt_loss` from torchaudio, which exists only in +torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0 +""" + + +import k2 +import torch +import torch.nn as nn +import torchaudio +import torchaudio.functional +from encoder_interface import EncoderInterface + +from icefall.utils import add_sos + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, C) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, C) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, C). It should contain + one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, C) and (N, U, C). Its + output shape is (N, T, U, C). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + Returns: + Return the transducer loss. + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens = self.encoder(x, x_lens) + assert torch.all(x_lens > 0) + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + sos_y_padded = sos_y_padded.to(torch.int64) + + decoder_out = self.decoder(sos_y_padded) + + logits = self.joiner( + encoder_out=encoder_out, + decoder_out=decoder_out, + ) + + # rnnt_loss requires 0 padded targets + # Note: y does not start with SOS + y_padded = y.pad(mode="constant", padding_value=0) + + assert hasattr(torchaudio.functional, "rnnt_loss"), ( + f"Current torchaudio version: {torchaudio.__version__}\n" + "Please install a version >= 0.10.0" + ) + + loss = torchaudio.functional.rnnt_loss( + logits=logits, + targets=y_padded, + logit_lengths=x_lens, + target_lengths=y_lens, + blank=blank_id, + reduction="sum", + ) + + return loss diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py new file mode 100755 index 000000000..2f0604893 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py @@ -0,0 +1,293 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +(1) greedy search +./transducer_stateless2/pretrained.py \ + --checkpoint ./transducer_stateless2/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + --max-sym-per-frame 1 \ + /path/to/foo.wav \ + /path/to/bar.wav \ + +(2) beam search +./transducer_stateless2/pretrained.py \ + --checkpoint ./transducer_stateless2/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav \ + +(3) modified beam search +./transducer_stateless2/pretrained.py \ + --checkpoint ./transducer_stateless2/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav \ + +You can also use `./transducer_stateless2/exp/epoch-xx.pt`. + +Note: ./transducer_stateless2/exp/pretrained.pt is generated by +./transducer_stateless2/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="Used only when --method is beam_search and modified_beam_search ", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + with torch.no_grad(): + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) + + num_waves = encoder_out.size(0) + hyp_list = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_list = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + elif params.method == "modified_beam_search": + hyp_list = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + hyp_list.append(hyp) + + hyps = [sp.decode(hyp).split() for hyp in hyp_list] + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/transducer_stateless2/subsampling.py b/egs/librispeech/ASR/transducer_stateless2/subsampling.py new file mode 120000 index 000000000..af74db6e3 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/subsampling.py @@ -0,0 +1 @@ +../transducer_stateless/subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py new file mode 100755 index 000000000..8ceffb489 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/train.py @@ -0,0 +1,779 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./transducer_stateless2/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 0 \ + --exp-dir transducer_stateless2/exp \ + --full-libri 1 \ + --max-duration 250 \ + --lr-factor 2.5 +""" + + +import argparse +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from model import Transducer +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from transformer import Noam + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + transducer_stateless2/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer_stateless2/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lr-factor", + type=float, + default=5.0, + help="The lr_factor for Noam optimizer", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - attention_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for conformer + "feature_dim": 80, + "encoder_out_dim": 512, + "subsampling_factor": 4, + "attention_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + "vgg_frontend": False, + # parameters for Noam + "warm_step": 80000, # For the 100h subset, use 8k + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Conformer and Transformer + encoder = Conformer( + num_features=params.feature_dim, + output_dim=params.encoder_out_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.attention_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + vgg_frontend=params.vgg_frontend, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.encoder_out_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + input_dim=params.encoder_out_dim, + output_dim=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = model.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + loss = model( + x=feature, + x_lens=feature_lens, + y=y, + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + if params.print_diagnostics and batch_idx == 5: + return + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 800 + params.warm_step = 8000 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + model.device = device + + optimizer = Noam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + ) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2 ** 22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + return 1.0 <= c.duration <= 20.0 + + num_in_total = len(train_cuts) + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + try: + num_left = len(train_cuts) + num_removed = num_in_total - num_left + removed_percent = num_removed / num_in_total * 100 + + logging.info( + f"Before removing short and long utterances: {num_in_total}" + ) + logging.info(f"After removing short and long utterances: {num_left}") + logging.info( + f"Removed {num_removed} utterances ({removed_percent:.5f}%)" + ) + except TypeError as e: + # You can ignore this error as previous versions of Lhotse work fine + # for the above code. In recent versions of Lhotse, it uses + # lazy filter, producing cutsets that don't have the __len__ method + logging.info(str(e)) + + train_dl = librispeech.train_dataloaders(train_cuts) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer_stateless2/transformer.py b/egs/librispeech/ASR/transducer_stateless2/transformer.py new file mode 120000 index 000000000..e43f520f9 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless2/transformer.py @@ -0,0 +1 @@ +../transducer_stateless/transformer.py \ No newline at end of file From 5228b44de7c5b3982e5f7d6ca4ffec6b2fb3c5fe Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Tue, 19 Apr 2022 22:00:47 +0800 Subject: [PATCH 227/234] Support modified beam search decoding for streaming inference with Emformer model. --- .../transducer_emformer/streaming_decode.py | 197 +++++++++++++++--- .../streaming_feature_extractor.py | 53 +++-- 2 files changed, 210 insertions(+), 40 deletions(-) diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py index bb71310b7..f5e24a0d9 100755 --- a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py +++ b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py @@ -18,16 +18,23 @@ import argparse import logging +import warnings from pathlib import Path from typing import List, Optional, Tuple +import k2 import numpy as np import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import Hypothesis, HypothesisList, _get_hyps_shape from emformer import LOG_EPSILON, stack_states, unstack_states -from streaming_feature_extractor import FeatureExtractionStream +from streaming_feature_extractor import ( + FeatureExtractionStream, + GreedySearchStream, + ModifiedBeamSearchStream, +) from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -50,6 +57,7 @@ def get_parser(): help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) + parser.add_argument( "--avg", type=int, @@ -208,7 +216,7 @@ class StreamList(object): self, batch_size: int, context_size: int, - blank_id: int, + decoding_method: str, ): """ Args: @@ -216,14 +224,21 @@ class StreamList(object): Size of this batch. context_size: Context size of the RNN-T decoder model. - blank_id: - The ID of the blank symbol of the BPE model. + decoding_method: + Decoding method. The possible values are: + - greedy_search + - modified_beam_search """ + decoding_classes = { + "greedy_search": GreedySearchStream, + "modified_beam_search": ModifiedBeamSearchStream, + } + + assert decoding_method in decoding_classes + cls = decoding_classes[decoding_method] + self.streams = [ - FeatureExtractionStream( - context_size=context_size, blank_id=blank_id - ) - for _ in range(batch_size) + cls(context_size=context_size) for _ in range(batch_size) ] @property @@ -238,7 +253,7 @@ class StreamList(object): audio_samples: List[torch.Tensor], sampling_rate: float, ): - """Feeed audio samples to each stream. + """Feed audio samples to each stream. Args: audio_samples: A list of 1-D tensors containing the audio samples for each @@ -314,7 +329,7 @@ class StreamList(object): def greedy_search( model: nn.Module, - streams: List[FeatureExtractionStream], + streams: List[GreedySearchStream], encoder_out: torch.Tensor, sp: spm.SentencePieceProcessor, ): @@ -333,7 +348,15 @@ def greedy_search( blank_id = model.decoder.blank_id context_size = model.decoder.context_size device = model.device + assert len(streams) == encoder_out.size(0) + assert encoder_out.ndim == 3 + for s in streams: + if s.hyp is None: + s.hyp = Hypothesis( + ys=([blank_id] * context_size), + log_prob=torch.tensor([0.0], device=device), + ) if streams[0].decoder_out is None: decoder_input = torch.tensor( [stream.hyp.ys[-context_size:] for stream in streams], @@ -351,8 +374,6 @@ def greedy_search( dim=0, ) - assert encoder_out.ndim == 3 - T = encoder_out.size(1) for t in range(T): current_encoder_out = encoder_out[:, t] @@ -381,20 +402,132 @@ def greedy_search( ) for k, s in enumerate(streams): - logging.info( - f"Partial result {k}:\n{sp.decode(s.hyp.ys[context_size:])}" - ) + logging.info(f"Partial result {k}:\n{sp.decode(s.result)}") decoder_out_list = decoder_out.unbind(dim=0) - for i, d in enumerate(decoder_out_list): streams[i].decoder_out = d +def modified_beam_search( + model: nn.Module, + streams: List[ModifiedBeamSearchStream], + encoder_out: torch.Tensor, + sp: spm.SentencePieceProcessor, + beam: int = 4, +): + """ + Args: + model: + The RNN-T model. + stream: + A stream object. + encoder_out: + A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of + the encoder model. + sp: + The BPE model. + beam: + Number of active paths during the beam search. + """ + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + assert encoder_out.ndim == 3, encoder_out.shape + assert len(streams) == encoder_out.size(0) + batch_size = len(streams) + + for s in streams: + if len(s.hyps) == 0: + s.hyps.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + B = [s.hyps for s in streams] + + T = encoder_out.size(1) + for t in range(T): + current_encoder_out = encoder_out[:, t] + # current_encoder_out's shape: (batch_size, encoder_out_dim) + + hyps_shape = _get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1) + # decoder_out is of shape (num_hyps, decoder_output_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, encoder_out_dim) + + logits = model.joiner(current_encoder_out, decoder_out) + # logits is of shape (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token != blank_id: + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + streams[i].hyps = B[i] + logging.info(f"Partial result {i}:\n{sp.decode(streams[i].result)}") + + def process_features( model: nn.Module, features: torch.Tensor, streams: List[FeatureExtractionStream], + params: AttributeDict, sp: spm.SentencePieceProcessor, ) -> None: """Process features for each stream in parallel. @@ -406,6 +539,8 @@ def process_features( A 3-D tensor of shape (N, T, C). streams: A list of streams of size (N,). + params: + It is the return value of :func:`get_params`. sp: The BPE model. """ @@ -439,12 +574,25 @@ def process_features( for i, s in enumerate(state_list): streams[i].states = s - greedy_search( - model=model, - streams=streams, - encoder_out=encoder_out, - sp=sp, - ) + if params.decoding_method == "greedy_search": + greedy_search( + model=model, + streams=streams, + encoder_out=encoder_out, + sp=sp, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=streams, + encoder_out=encoder_out, + sp=sp, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) def decode_batch( @@ -479,7 +627,7 @@ def decode_batch( stream_list = StreamList( batch_size=batch_size, context_size=params.context_size, - blank_id=params.blank_id, + decoding_method=params.decoding_method, ) while not streaming_audio_samples.done: @@ -497,11 +645,12 @@ def decode_batch( model=model, features=features, streams=active_streams, + params=params, sp=sp, ) results = [] for s in stream_list.streams: - text = sp.decode(s.hyp.ys[params.context_size :]) + text = sp.decode(s.result) results.append(text) return results diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py index b20f6502f..a040cc09c 100644 --- a/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py +++ b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py @@ -17,7 +17,7 @@ from typing import List, Optional import torch -from beam_search import Hypothesis +from beam_search import Hypothesis, HypothesisList from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature @@ -41,14 +41,10 @@ def _create_streaming_feature_extractor() -> OnlineFeature: class FeatureExtractionStream(object): - def __init__(self, context_size: int, blank_id: int = 0) -> None: - """Context size of the RNN-T decoder model.""" + def __init__( + self, + ) -> None: self.feature_extractor = _create_streaming_feature_extractor() - self.hyp = Hypothesis( - ys=([blank_id] * context_size), - log_prob=torch.tensor([0.0]), - ) # for greedy search, will extend it to beam search - # It contains a list of 1-D tensors representing the feature frames. self.feature_frames: List[torch.Tensor] = [] @@ -58,11 +54,6 @@ class FeatureExtractionStream(object): # encoder layer. self.states: Optional[List[List[torch.Tensor]]] = None - # For the RNN-T decoder, it contains the decoder output - # corresponding to the decoder input self.hyp.ys[-context_size:] - # Its shape is (decoder_out_dim,) - self.decoder_out: Optional[torch.Tensor] = None - # After calling `self.input_finished()`, we set this flag to True self._done = False @@ -85,9 +76,9 @@ class FeatureExtractionStream(object): check to ensure that the input sampling rate equals to the one used in the extractor. If they are not equal, then no resampling will be performed; instead an error will be thrown. - waveform: - A 1-D torch tensor of dtype torch.float32 containing audio samples. - It should be on CPU. + waveform: + A 1-D torch tensor of dtype torch.float32 containing audio samples. + It should be on CPU. """ self.feature_extractor.accept_waveform( sampling_rate=sampling_rate, @@ -114,3 +105,33 @@ class FeatureExtractionStream(object): frame = self.feature_extractor.get_frame(self.num_fetched_frames) self.feature_frames.append(frame) self.num_fetched_frames += 1 + + +class GreedySearchStream(FeatureExtractionStream): + def __init__(self, context_size: int) -> None: + """FeatureExtractionStream class for greedy search.""" + super().__init__() + self.context_size = context_size + # For the RNN-T decoder, it contains the decoder output + # corresponding to the decoder input self.hyp.ys[-context_size:] + # Its shape is (decoder_out_dim,) + self.hyp: Hypothesis = None + self.decoder_out: Optional[torch.Tensor] = None + + @property + def result(self) -> List[int]: + return self.hyp.ys[self.context_size :] + + +class ModifiedBeamSearchStream(FeatureExtractionStream): + def __init__(self, context_size: int) -> None: + """FeatureExtractionStream class for modified beam search decoding.""" + super().__init__() + self.context_size = context_size + self.hyps = HypothesisList() + self.best_hyp = None + + @property + def result(self) -> List[int]: + best_hyp = self.hyps.get_most_probable(length_norm=True) + return best_hyp.ys[self.context_size :] From e74654c2a242677bcdc1481b1864515ed52a6f52 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Tue, 19 Apr 2022 22:05:14 +0800 Subject: [PATCH 228/234] Formatted imports. --- egs/librispeech/ASR/transducer_emformer/streaming_decode.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py index f5e24a0d9..df3303100 100755 --- a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py +++ b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py @@ -37,11 +37,7 @@ from streaming_feature_extractor import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import AttributeDict, setup_logger From 3607c516d6512295ff50c7ba1ce9ca8c231f5782 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 20 Apr 2022 11:15:10 +0800 Subject: [PATCH 229/234] Update results for torchaudio RNN-T. (#322) --- egs/librispeech/ASR/RESULTS.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index ac8b3ec75..2fed7feed 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -7,7 +7,7 @@ This is with a reworked version of the conformer encoder, with many changes. #### Training on fulll librispeech -using commit `34aad74a2c849542dd5f6359c9e6b527e8782fd6`. +Using commit `34aad74a2c849542dd5f6359c9e6b527e8782fd6`. See The WERs are: @@ -353,8 +353,12 @@ You can find a pretrained model by visiting ##### 2022-04-19 [transducer_stateless2](./transducer_stateless2) + This version uses torchaudio's RNN-T loss. +Using commit `fce7f3cd9a486405ee008bcbe4999264f27774a3`. +See + | | test-clean | test-other | comment | |-------------------------------------|------------|------------|--------------------------------------------------------------------------------| | greedy search (max sym per frame 1) | 2.65 | 6.30 | --epoch 59 --avg 10 --max-duration 600 | From 24db3a19347892a7467eee2b3d80d21cc2519d4b Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Wed, 20 Apr 2022 14:21:45 +0800 Subject: [PATCH 230/234] update emformer_pruned_transducer_stateless/emformer.py --- .../ASR/emformer_pruned_transducer_stateless/emformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py index b6f93b4c7..9973d6a15 100644 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py @@ -1280,6 +1280,7 @@ class Emformer(EncoderInterface): self.subsampling_factor = subsampling_factor self.right_context_length = right_context_length + self.chunk_length = chunk_length if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") if chunk_length % 4 != 0: From cf0ce8db322e48b2148c1e6aee59801391b620a8 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Thu, 21 Apr 2022 19:48:35 +0800 Subject: [PATCH 231/234] Fixed streaming decoding codes for emformer model. --- .../beam_search.py | 4 +- .../transducer_emformer/streaming_decode.py | 105 ++++++++---------- .../streaming_feature_extractor.py | 63 ++++------- 3 files changed, 74 insertions(+), 98 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 2cb7a8cba..574c637ae 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -367,7 +367,7 @@ class HypothesisList(object): return ", ".join(s) -def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: +def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: """Return a ragged shape with axes [utt][num_hyps]. Args: @@ -431,7 +431,7 @@ def modified_beam_search( current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - hyps_shape = _get_hyps_shape(B).to(device) + hyps_shape = get_hyps_shape(B).to(device) A = [list(b) for b in B] B = [HypothesisList() for _ in range(batch_size)] diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py index df3303100..c5bcb3aee 100755 --- a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py +++ b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py @@ -28,16 +28,16 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import Hypothesis, HypothesisList, _get_hyps_shape +from beam_search import Hypothesis, HypothesisList, get_hyps_shape from emformer import LOG_EPSILON, stack_states, unstack_states -from streaming_feature_extractor import ( - FeatureExtractionStream, - GreedySearchStream, - ModifiedBeamSearchStream, -) +from streaming_feature_extractor import FeatureExtractionStream from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import AttributeDict, setup_logger @@ -225,16 +225,12 @@ class StreamList(object): - greedy_search - modified_beam_search """ - decoding_classes = { - "greedy_search": GreedySearchStream, - "modified_beam_search": ModifiedBeamSearchStream, - } - - assert decoding_method in decoding_classes - cls = decoding_classes[decoding_method] self.streams = [ - cls(context_size=context_size) for _ in range(batch_size) + FeatureExtractionStream( + context_size=context_size, decoding_method=decoding_method + ) + for _ in range(batch_size) ] @property @@ -325,7 +321,7 @@ class StreamList(object): def greedy_search( model: nn.Module, - streams: List[GreedySearchStream], + streams: List[FeatureExtractionStream], encoder_out: torch.Tensor, sp: spm.SentencePieceProcessor, ): @@ -333,36 +329,31 @@ def greedy_search( Args: model: The RNN-T model. - stream: - A stream object. + streams: + A list of GreedySearchDecodingStream objects. encoder_out: A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of the encoder model. sp: The BPE model. """ - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = model.device assert len(streams) == encoder_out.size(0) assert encoder_out.ndim == 3 - for s in streams: - if s.hyp is None: - s.hyp = Hypothesis( - ys=([blank_id] * context_size), - log_prob=torch.tensor([0.0], device=device), - ) + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + T = encoder_out.size(1) + if streams[0].decoder_out is None: + for stream in streams: + stream.hyp = [blank_id] * context_size decoder_input = torch.tensor( - [stream.hyp.ys[-context_size:] for stream in streams], + [stream.hyp[-context_size:] for stream in streams], device=device, dtype=torch.int64, ) - decoder_out = model.decoder( - decoder_input, - need_pad=False, - ).squeeze(1) + decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1) # decoder_out is of shape (N, decoder_out_dim) else: decoder_out = torch.stack( @@ -370,7 +361,6 @@ def greedy_search( dim=0, ) - T = encoder_out.size(1) for t in range(T): current_encoder_out = encoder_out[:, t] # current_encoder_out's shape: (batch_size, encoder_out_dim) @@ -383,22 +373,23 @@ def greedy_search( emitted = False for i, v in enumerate(y): if v != blank_id: - streams[i].hyp.ys.append(v) + streams[i].hyp.append(v) emitted = True - if emitted: # update decoder output decoder_input = torch.tensor( - [stream.hyp.ys[-context_size:] for stream in streams], + [stream.hyp[-context_size:] for stream in streams], device=device, dtype=torch.int64, ) - decoder_out = model.decoder(decoder_input, need_pad=False).squeeze( - 1 - ) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ).squeeze(1) - for k, s in enumerate(streams): - logging.info(f"Partial result {k}:\n{sp.decode(s.result)}") + for k, stream in enumerate(streams): + result = sp.decode(stream.decoding_result()) + logging.info(f"Partial result {k}:\n{result}") decoder_out_list = decoder_out.unbind(dim=0) for i, d in enumerate(decoder_out_list): @@ -407,7 +398,7 @@ def greedy_search( def modified_beam_search( model: nn.Module, - streams: List[ModifiedBeamSearchStream], + streams: List[FeatureExtractionStream], encoder_out: torch.Tensor, sp: spm.SentencePieceProcessor, beam: int = 4, @@ -426,36 +417,35 @@ def modified_beam_search( beam: Number of active paths during the beam search. """ + assert encoder_out.ndim == 3, encoder_out.shape + assert len(streams) == encoder_out.size(0) + blank_id = model.decoder.blank_id context_size = model.decoder.context_size device = model.device - assert encoder_out.ndim == 3, encoder_out.shape - assert len(streams) == encoder_out.size(0) batch_size = len(streams) + T = encoder_out.size(1) - for s in streams: - if len(s.hyps) == 0: - s.hyps.add( + for stream in streams: + if len(stream.hyps) == 0: + stream.hyps.add( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), ) ) - - B = [s.hyps for s in streams] - - T = encoder_out.size(1) + B = [stream.hyps for stream in streams] for t in range(T): current_encoder_out = encoder_out[:, t] # current_encoder_out's shape: (batch_size, encoder_out_dim) - hyps_shape = _get_hyps_shape(B).to(device) + hyps_shape = get_hyps_shape(B).to(device) A = [list(b) for b in B] B = [HypothesisList() for _ in range(batch_size)] - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ys_log_probs = torch.stack( + [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 ) # (num_hyps, 1) decoder_input = torch.tensor( @@ -516,7 +506,8 @@ def modified_beam_search( B[i].add(new_hyp) streams[i].hyps = B[i] - logging.info(f"Partial result {i}:\n{sp.decode(streams[i].result)}") + result = sp.decode(streams[i].decoding_result()) + logging.info(f"Partial result {i}:\n{result}") def process_features( @@ -645,8 +636,8 @@ def decode_batch( sp=sp, ) results = [] - for s in stream_list.streams: - text = sp.decode(s.result) + for stream in stream_list.streams: + text = sp.decode(stream.decoding_result()) results.append(text) return results diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py index a040cc09c..c3d9a5675 100644 --- a/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py +++ b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py @@ -14,11 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional - -import torch -from beam_search import Hypothesis, HypothesisList +from beam_search import HypothesisList from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature +from typing import List, Optional +import torch def _create_streaming_feature_extractor() -> OnlineFeature: @@ -41,21 +40,28 @@ def _create_streaming_feature_extractor() -> OnlineFeature: class FeatureExtractionStream(object): - def __init__( - self, - ) -> None: + def __init__(self, context_size: int, decoding_method: str) -> None: self.feature_extractor = _create_streaming_feature_extractor() # It contains a list of 1-D tensors representing the feature frames. self.feature_frames: List[torch.Tensor] = [] - self.num_fetched_frames = 0 + # After calling `self.input_finished()`, we set this flag to True + self._done = False # For the emformer model, it contains the states of each # encoder layer. self.states: Optional[List[List[torch.Tensor]]] = None - # After calling `self.input_finished()`, we set this flag to True - self._done = False + # It use different attributes for different decoding methods. + self.context_size = context_size + self.decoding_method = decoding_method + if decoding_method == "greedy_search": + self.hyp: List[int] = None + self.decoder_out: Optional[torch.Tensor] = None + elif decoding_method == "modified_beam_search": + self.hyps = HypothesisList() + else: + raise ValueError(f"Unsupported decoding method: {decoding_method}") def accept_waveform( self, @@ -106,32 +112,11 @@ class FeatureExtractionStream(object): self.feature_frames.append(frame) self.num_fetched_frames += 1 - -class GreedySearchStream(FeatureExtractionStream): - def __init__(self, context_size: int) -> None: - """FeatureExtractionStream class for greedy search.""" - super().__init__() - self.context_size = context_size - # For the RNN-T decoder, it contains the decoder output - # corresponding to the decoder input self.hyp.ys[-context_size:] - # Its shape is (decoder_out_dim,) - self.hyp: Hypothesis = None - self.decoder_out: Optional[torch.Tensor] = None - - @property - def result(self) -> List[int]: - return self.hyp.ys[self.context_size :] - - -class ModifiedBeamSearchStream(FeatureExtractionStream): - def __init__(self, context_size: int) -> None: - """FeatureExtractionStream class for modified beam search decoding.""" - super().__init__() - self.context_size = context_size - self.hyps = HypothesisList() - self.best_hyp = None - - @property - def result(self) -> List[int]: - best_hyp = self.hyps.get_most_probable(length_norm=True) - return best_hyp.ys[self.context_size :] + def decoding_result(self) -> List[int]: + """Obtain current decoding result.""" + if self.decoding_method == "greedy_search": + return self.hyp[self.context_size :] + else: + assert self.decoding_method == "modified_beam_search" + best_hyp = self.hyps.get_most_probable(length_norm=True) + return best_hyp.ys[self.context_size :] From d20a852f61229f340c9d2d7157fbc13a02222533 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Thu, 21 Apr 2022 19:55:30 +0800 Subject: [PATCH 232/234] Fixed docs. --- egs/librispeech/ASR/transducer_emformer/streaming_decode.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py index c5bcb3aee..8ebfbb210 100755 --- a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py +++ b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py @@ -330,7 +330,7 @@ def greedy_search( model: The RNN-T model. streams: - A list of GreedySearchDecodingStream objects. + A list of stream objects. encoder_out: A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of the encoder model. @@ -407,8 +407,8 @@ def modified_beam_search( Args: model: The RNN-T model. - stream: - A stream object. + streams: + A list of stream objects. encoder_out: A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of the encoder model. From e97c9fbdbf47d04a16ce0a8afe007d921c7dfeab Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 22 Apr 2022 11:04:50 +0800 Subject: [PATCH 233/234] Sorted imports for transducer_emformer/streaming_feature_extractor.py --- .../streaming_feature_extractor.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py index c3d9a5675..4d405cad1 100644 --- a/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py +++ b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py @@ -14,10 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Optional + +import torch from beam_search import HypothesisList from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature -from typing import List, Optional -import torch def _create_streaming_feature_extractor() -> OnlineFeature: @@ -41,6 +42,15 @@ def _create_streaming_feature_extractor() -> OnlineFeature: class FeatureExtractionStream(object): def __init__(self, context_size: int, decoding_method: str) -> None: + """ + Args: + context_size: + Context size of the RNN-T decoder model. + decoding_method: + Decoding method. The possible values are: + - greedy_search + - modified_beam_search + """ self.feature_extractor = _create_streaming_feature_extractor() # It contains a list of 1-D tensors representing the feature frames. self.feature_frames: List[torch.Tensor] = [] From ece99a862b0f5333de79b98be70e8bcaa06343d2 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 22 Apr 2022 11:23:23 +0800 Subject: [PATCH 234/234] Minor fix for transducer_emformer/streaming_feature_extractor.py --- .../ASR/transducer_emformer/streaming_feature_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py index 4d405cad1..ea323103b 100644 --- a/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py +++ b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py @@ -66,7 +66,7 @@ class FeatureExtractionStream(object): self.context_size = context_size self.decoding_method = decoding_method if decoding_method == "greedy_search": - self.hyp: List[int] = None + self.hyp: Optional[List[int]] = None self.decoder_out: Optional[torch.Tensor] = None elif decoding_method == "modified_beam_search": self.hyps = HypothesisList()