diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py
index 00e906691..a909b2a74 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py
@@ -116,7 +116,7 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model

 from icefall.checkpoint import (
     average_checkpoints,
@@ -302,6 +302,8 @@ def get_parser():
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )

+    add_model_arguments(parser)
+
     return parser

@@ -354,13 +356,6 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)

-    # feature_lens += params.left_context
-    # feature = torch.nn.functional.pad(
-    #     feature,
-    #     pad=(0, 0, 0, params.left_context),
-    #     value=LOG_EPS,
-    # )
-
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py
index 47b2c7b2b..45409ccea 100644
--- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py
@@ -15,7 +15,7 @@
 # limitations under the License.

 import copy
-from typing import Tuple
+from typing import List, Optional, Tuple

 import torch
 from encoder_interface import EncoderInterface
@@ -47,6 +47,9 @@ class RNN(EncoderInterface):
         Dropout rate (default=0.1).
       layer_dropout (float):
         Dropout value for model-level warmup (default=0.075).
+      aux_layer_period (int):
+        Period of the auxiliary layers whose outputs are randomly combined
+        during training. If not larger than 0, the random combiner is not used.
     """

     def __init__(
@@ -58,6 +61,7 @@ class RNN(EncoderInterface):
         num_encoder_layers: int = 12,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
+        aux_layer_period: int = 3,
     ) -> None:
         super(RNN, self).__init__()
@@ -79,7 +83,19 @@ class RNN(EncoderInterface):
         encoder_layer = RNNEncoderLayer(
             d_model, dim_feedforward, dropout, layer_dropout
         )
-        self.encoder = RNNEncoder(encoder_layer, num_encoder_layers)
+        self.encoder = RNNEncoder(
+            encoder_layer,
+            num_encoder_layers,
+            aux_layers=list(
+                range(
+                    num_encoder_layers // 3,
+                    num_encoder_layers - 1,
+                    aux_layer_period,
+                )
+            )
+            if aux_layer_period > 0
+            else None,
+        )

     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@@ -306,13 +322,31 @@ class RNNEncoder(nn.Module):
       The number of sub-encoder-layers in the encoder (required).
     """

-    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
+    def __init__(
+        self,
+        encoder_layer: nn.Module,
+        num_layers: int,
+        aux_layers: Optional[List[int]] = None,
+    ) -> None:
         super(RNNEncoder, self).__init__()
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
         self.num_layers = num_layers

+        self.use_random_combiner = False
+        if aux_layers is not None:
+            assert len(set(aux_layers)) == len(aux_layers)
+            assert num_layers - 1 not in aux_layers
+            self.use_random_combiner = True
+            self.aux_layers = aux_layers + [num_layers - 1]
+            self.combiner = RandomCombine(
+                num_inputs=len(self.aux_layers),
+                final_weight=0.5,
+                pure_prob=0.333,
+                stddev=2.0,
+            )
+
     def forward(self, src: torch.Tensor, warmup: float = 1.0) -> torch.Tensor:
         """
         Pass the input through the encoder layer in turn.
@@ -328,8 +362,16 @@ class RNNEncoder(nn.Module):
         """
         output = src

-        for layer_index, mod in enumerate(self.layers):
+        outputs = []
+
+        for i, mod in enumerate(self.layers):
             output = mod(output, warmup=warmup)
+            if self.use_random_combiner:
+                if i in self.aux_layers:
+                    outputs.append(output)
+
+        if self.use_random_combiner:
+            output = self.combiner(outputs)

         return output

@@ -459,6 +501,244 @@ class Conv2dSubsampling(nn.Module):
         return x


+class RandomCombine(nn.Module):
+    """
+    This module combines a list of Tensors, all with the same shape, to
+    produce a single output of that same shape which, at training time,
+    is a random combination of all the inputs; at test time it is simply
+    the last input.
+
+    The idea is that the list of Tensors will be a list of outputs of multiple
+    encoder layers.  This has a similar effect as iterated loss. (See:
+    DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
+    NETWORKS).
+    """
+
+    def __init__(
+        self,
+        num_inputs: int,
+        final_weight: float = 0.5,
+        pure_prob: float = 0.5,
+        stddev: float = 2.0,
+    ) -> None:
+        """
+        Args:
+          num_inputs:
+            The number of tensor inputs, which equals the number of layers'
+            outputs that are fed into this module.  E.g. in an 18-layer neural
+            net, if we output layers 12, 16, 18, num_inputs would be 3.
+          final_weight:
+            The amount of weight or probability we assign to the
+            final layer when randomly choosing layers or when choosing
+            continuous layer weights.
+          pure_prob:
+            The probability, on each frame, with which we choose
+            only a single layer to output (rather than an interpolation).
+          stddev:
+            A standard deviation that we add to log-probs for computing
+            randomized weights.
+
+        The method of choosing which layers, or combinations of layers, to use,
+        is conceptually as follows::
+
+            With probability `pure_prob`::
+                With probability `final_weight`: choose final layer,
+                Else: choose random non-final layer.
+            Else::
+                Choose initial log-weights that correspond to assigning
+                weight `final_weight` to the final layer and equal
+                weights to other layers; then add Gaussian noise with
+                standard deviation `stddev` to these log-weights, and normalize
+                to weights (note: the average weight assigned to the
+                final layer here will not be `final_weight` if stddev>0).
+        """
+        super().__init__()
+        assert 0 <= pure_prob <= 1, pure_prob
+        assert 0 < final_weight < 1, final_weight
+        assert num_inputs >= 1
+
+        self.num_inputs = num_inputs
+        self.final_weight = final_weight
+        self.pure_prob = pure_prob
+        self.stddev = stddev
+
+        self.final_log_weight = (
+            torch.tensor(
+                (final_weight / (1 - final_weight)) * (self.num_inputs - 1)
+            )
+            .log()
+            .item()
+        )
+
+    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
+        """Forward function.
+        Args:
+          inputs:
+            A list of Tensors, e.g. from various layers of a transformer.
+            All must be of the same shape, (*, num_channels).
+        Returns:
+          A Tensor of shape (*, num_channels).  In test mode
+          this is just the final input.
+ """ + num_inputs = self.num_inputs + assert len(inputs) == num_inputs + if not self.training or torch.jit.is_scripting(): + return inputs[-1] + + # Shape of weights: (*, num_inputs) + num_channels = inputs[0].shape[-1] + num_frames = inputs[0].numel() // num_channels + + ndim = inputs[0].ndim + # stacked_inputs: (num_frames, num_channels, num_inputs) + stacked_inputs = torch.stack(inputs, dim=ndim).reshape( + (num_frames, num_channels, num_inputs) + ) + + # weights: (num_frames, num_inputs) + weights = self._get_random_weights( + inputs[0].dtype, inputs[0].device, num_frames + ) + + weights = weights.reshape(num_frames, num_inputs, 1) + # ans: (num_frames, num_channels, 1) + ans = torch.matmul(stacked_inputs, weights) + # ans: (*, num_channels) + + ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,)) + + # The following if causes errors for torch script in torch 1.6.0 + # if __name__ == "__main__": + # # for testing only... + # print("Weights = ", weights.reshape(num_frames, num_inputs)) + return ans + + def _get_random_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ) -> torch.Tensor: + """Return a tensor of random weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired + Returns: + A tensor of shape (num_frames, self.num_inputs), such that + `ans.sum(dim=1)` is all ones. + """ + pure_prob = self.pure_prob + if pure_prob == 0.0: + return self._get_random_mixed_weights(dtype, device, num_frames) + elif pure_prob == 1.0: + return self._get_random_pure_weights(dtype, device, num_frames) + else: + p = self._get_random_pure_weights(dtype, device, num_frames) + m = self._get_random_mixed_weights(dtype, device, num_frames) + return torch.where( + torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m + ) + + def _get_random_pure_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ): + """Return a tensor of random one-hot weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired. + Returns: + A one-hot tensor of shape `(num_frames, self.num_inputs)`, with + exactly one weight equal to 1.0 on each frame. + """ + final_prob = self.final_weight + + # final contains self.num_inputs - 1 in all elements + final = torch.full((num_frames,), self.num_inputs - 1, device=device) + # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa + nonfinal = torch.randint( + self.num_inputs - 1, (num_frames,), device=device + ) + + indexes = torch.where( + torch.rand(num_frames, device=device) < final_prob, final, nonfinal + ) + ans = torch.nn.functional.one_hot( + indexes, num_classes=self.num_inputs + ).to(dtype=dtype) + return ans + + def _get_random_mixed_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ): + """Return a tensor of random one-hot weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired. + Returns: + A tensor of shape (num_frames, self.num_inputs), which elements + in [0..1] that sum to one over the second axis, i.e. + `ans.sum(dim=1)` is all ones. 
+ """ + logprobs = ( + torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) + * self.stddev + ) + logprobs[:, -1] += self.final_log_weight + return logprobs.softmax(dim=1) + + +def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): + print( + f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" # noqa + ) + num_inputs = 3 + num_channels = 50 + m = RandomCombine( + num_inputs=num_inputs, + final_weight=final_weight, + pure_prob=pure_prob, + stddev=stddev, + ) + + x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] + + y = m(x) + assert y.shape == x[0].shape + assert torch.allclose(y, x[0]) # .. since actually all ones. + + +def _test_random_combine_main(): + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.0) + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.3) + _test_random_combine(0.5, 1, 0.3) + _test_random_combine(0.5, 0.5, 0.3) + + feature_dim = 50 + c = RNN(num_features=feature_dim, d_model=128) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + if __name__ == "__main__": feature_dim = 50 m = RNN(num_features=feature_dim, d_model=128) @@ -470,3 +750,5 @@ if __name__ == "__main__": torch.full((batch_size,), seq_len, dtype=torch.int64), warmup=0.5, ) + + _test_random_combine_main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py index bf48b231b..79b0a45a2 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -42,7 +42,7 @@ from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet from torch.nn.utils.rnn import pad_sequence -from train import get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, @@ -177,6 +177,8 @@ def get_parser(): help="The number of streams that can be decoded parallel.", ) + add_model_arguments(parser) + return parser @@ -434,9 +436,7 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state( - params.left_context, device=device - ) + initial_states = model.encoder.get_init_states(device=device) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. 
         decode_stream = DecodeStream(
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py
index 738a880eb..0826c72e9 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py
@@ -57,12 +57,12 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from lstm import RNN
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
+from lstm import RNN
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
@@ -86,6 +86,24 @@ LRSchedulerType = Union[
 ]


+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=int,
+        default=20,
+        help="Number of RNN encoder layers.",
+    )
+
+    parser.add_argument(
+        "--aux-layer-period",
+        type=int,
+        default=3,
+        help="""Period of the auxiliary layers whose outputs are randomly
+        combined during training. If not larger than 0, the random combiner
+        is not used.""",
+    )
+
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -279,6 +297,8 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    add_model_arguments(parser)
+
     return parser

@@ -341,7 +361,6 @@ def get_params() -> AttributeDict:
             "subsampling_factor": 4,
             "encoder_dim": 512,
             "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
             # parameters for decoder
             "decoder_dim": 512,
             # parameters for joiner
@@ -363,6 +382,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         d_model=params.encoder_dim,
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
+        aux_layer_period=params.aux_layer_period,
     )
     return encoder
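
For reference, here is a minimal standalone sketch (not part of the patch) of how the new --aux-layer-period option selects the auxiliary layers that feed RandomCombine, and how the combiner behaves in train versus eval mode. It assumes the patched lstm.py above is importable from the recipe directory; the dummy tensor shapes and the asserted layer indices are illustrative only.

    # sketch.py - illustrative only; run from the recipe directory so that
    # the patched lstm.py (which now defines RandomCombine) is importable.
    import torch

    from lstm import RandomCombine

    num_encoder_layers = 20  # the new --num-encoder-layers default
    aux_layer_period = 3  # the new --aux-layer-period default

    # Same expression as in RNN.__init__: tap layers from one third of the
    # depth upwards, every aux_layer_period layers, excluding the final layer.
    aux_layers = list(
        range(num_encoder_layers // 3, num_encoder_layers - 1, aux_layer_period)
    )
    assert aux_layers == [6, 9, 12, 15, 18]

    # RNNEncoder appends the final layer itself, so the combiner receives
    # len(aux_layers) + 1 inputs (6 here), constructed with the same
    # settings as in RNNEncoder.__init__.
    combiner = RandomCombine(
        num_inputs=len(aux_layers) + 1,
        final_weight=0.5,
        pure_prob=0.333,
        stddev=2.0,
    )

    # Dummy (seq_len, batch, d_model) outputs, one per tapped layer.
    inputs = [torch.randn(20, 5, 512) for _ in range(len(aux_layers) + 1)]

    combiner.train()
    y = combiner(inputs)  # a random per-frame mixture of all the inputs
    assert y.shape == inputs[0].shape

    combiner.eval()
    y = combiner(inputs)  # in eval mode, simply the last input
    assert torch.equal(y, inputs[-1])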