From d58002c4146e24d3e19b3ece0ce90ef29c32bdff Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 8 Apr 2022 20:31:32 +0800 Subject: [PATCH] update emformer.py --- .pre-commit-config.yaml | 1 + .../emformer.py | 406 ++++++++++-------- 2 files changed, 233 insertions(+), 174 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b59784dbf..62d34864b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,7 @@ repos: hooks: - id: black args: [--line-length=80] + additional_dependencies: ['click==8.0.1'] - repo: https://github.com/PyCQA/flake8 rev: 3.9.2 diff --git a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py index 91bb571c5..4ba19ebae 100644 --- a/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/emformer_pruned_transducer_stateless/emformer.py @@ -1,3 +1,22 @@ +# Copyright 2022 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# It is modified based on +# https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py. + import math import warnings from typing import List, Optional, Tuple @@ -22,29 +41,32 @@ def _get_activation_module(activation: str) -> nn.Module: def _get_weight_init_gains( - weight_init_scale_strategy: Optional[str], - num_layers: int + weight_init_scale_strategy: Optional[str], num_layers: int ) -> List[Optional[float]]: if weight_init_scale_strategy is None: return [None for _ in range(num_layers)] elif weight_init_scale_strategy == "depthwise": - return [1.0 / math.sqrt(layer_idx + 1) - for layer_idx in range(num_layers)] + return [ + 1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers) + ] elif weight_init_scale_strategy == "constant": return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)] else: - raise ValueError(f"Unsupported weight_init_scale_strategy value" - f"{weight_init_scale_strategy}") + raise ValueError( + f"Unsupported weight_init_scale_strategy value" + f"{weight_init_scale_strategy}" + ) def _gen_attention_mask_block( col_widths: List[int], col_mask: List[bool], num_rows: int, - device: torch.device + device: torch.device, ) -> torch.Tensor: - assert len(col_widths) == len(col_mask), ( - "Length of col_widths must match that of col_mask") + assert len(col_widths) == len( + col_mask + ), "Length of col_widths must match that of col_mask" mask_block = [ torch.ones(num_rows, col_width, device=device) @@ -99,9 +121,7 @@ class EmformerAttention(nn.Module): self.scaling = (self.embed_dim // self.nhead) ** -0.5 - self.emb_to_key_value = nn.Linear( - embed_dim, 2 * embed_dim, bias=True - ) + self.emb_to_key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=True) self.emb_to_query = nn.Linear(embed_dim, embed_dim, bias=True) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) @@ -119,7 +139,7 @@ class 
EmformerAttention(nn.Module): attention_mask: torch.Tensor, padding_mask: Optional[torch.Tensor], ) -> torch.Tensor: - """ Given the entire attention weights, mask out unecessary connections + """Given the entire attention weights, mask out unecessary connections and optionally with padding positions, to obtain underlying chunk-wise attention probabilities. @@ -154,7 +174,7 @@ class EmformerAttention(nn.Module): ) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), - self.negative_inf + self.negative_inf, ) attention_weights_float = attention_weights_float.view( B * self.nhead, Q, -1 @@ -164,9 +184,7 @@ class EmformerAttention(nn.Module): attention_weights_float, dim=-1 ).type_as(attention_weights) attention_probs = nn.functional.dropout( - attention_probs, - p=float(self.dropout), - training=self.training + attention_probs, p=float(self.dropout), training=self.training ) return attention_probs @@ -181,7 +199,7 @@ class EmformerAttention(nn.Module): left_context_key: Optional[torch.Tensor] = None, left_context_val: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ Underlying chunk-wise attention implementation. + """Underlying chunk-wise attention implementation. L: length of left_context; S: length of summary; @@ -242,14 +260,28 @@ class EmformerAttention(nn.Module): # [mems, right context, left context, uttrance] M = memory.size(0) R = right_context.size(0) - key = torch.cat([key[:M + R], left_context_key, key[M + R:]]) - value = torch.cat([value[:M + R], left_context_val, value[M + R:]]) + right_context_end_idx = M + R + key = torch.cat( + [ + key[:right_context_end_idx], + left_context_key, + key[right_context_end_idx:], + ] + ) + value = torch.cat( + [ + value[:right_context_end_idx], + left_context_val, + value[right_context_end_idx:], + ] + ) # Compute attention weights from query, key, and value. reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous().view( - -1, B * self.nhead, self.embed_dim // self.nhead - ).transpose(0, 1) for tensor in [query, key, value] + tensor.contiguous() + .view(-1, B * self.nhead, self.embed_dim // self.nhead) + .transpose(0, 1) + for tensor in [query, key, value] ] attention_weights = torch.bmm( reshaped_query * self.scaling, reshaped_key.transpose(1, 2) @@ -272,18 +304,21 @@ class EmformerAttention(nn.Module): attention = torch.bmm(attention_probs, reshaped_value) Q = query.size(0) assert attention.shape == ( - B * self.nhead, Q, self.embed_dim // self.nhead, + B * self.nhead, + Q, + self.embed_dim // self.nhead, ) - attention = attention.transpose(0, 1).contiguous().view( - Q, B, self.embed_dim + attention = ( + attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) ) # Apply output projection. outputs = self.out_proj(attention) S = summary.size(0) - output_right_context_utterance = outputs[:Q - S] - output_memory = outputs[Q - S:] + summary_start_idx = Q - S + output_right_context_utterance = outputs[:summary_start_idx] + output_memory = outputs[summary_start_idx:] if self.tanh_on_mem: output_memory = torch.tanh(output_memory) else: @@ -331,15 +366,14 @@ class EmformerAttention(nn.Module): - output of right context and utterance, with shape (R + U, B, D). - memory output, with shape (M, B, D), where M = S - 1 or M = 0. 
""" - output_right_context_utterance, output_memory, _, _ = \ - self._forward_impl( - utterance, - lengths, - right_context, - summary, - memory, - attention_mask - ) + ( + output_right_context_utterance, + output_memory, + _, + _, + ) = self._forward_impl( + utterance, lengths, right_context, summary, memory, attention_mask + ) return output_right_context_utterance, output_memory[:-1] @torch.jit.export @@ -394,29 +428,38 @@ class EmformerAttention(nn.Module): # query: [right context, utterance, summary] Q = right_context.size(0) + utterance.size(0) + summary.size(0) # key, value: [memory, right context, left context, uttrance] - KV = memory.size(0) + right_context.size(0) + \ - left_context_key.size(0) + utterance.size(0) - attention_mask = torch.zeros( - Q, KV - ).to(dtype=torch.bool, device=utterance.device) + KV = ( + memory.size(0) + + right_context.size(0) + + left_context_key.size(0) + + utterance.size(0) + ) + attention_mask = torch.zeros(Q, KV).to( + dtype=torch.bool, device=utterance.device + ) # Disallow attention bettween the summary vector with the memory bank - attention_mask[-1, :memory.size(0)] = True - output_right_context_utterance, output_memory, key, value = \ - self._forward_impl( - utterance, - lengths, - right_context, - summary, - memory, - attention_mask, - left_context_key=left_context_key, - left_context_val=left_context_val, - ) + attention_mask[-1, : memory.size(0)] = True + ( + output_right_context_utterance, + output_memory, + key, + value, + ) = self._forward_impl( + utterance, + lengths, + right_context, + summary, + memory, + attention_mask, + left_context_key=left_context_key, + left_context_val=left_context_val, + ) + right_context_end_idx = memory.size(0) + right_context.size(0) return ( output_right_context_utterance, output_memory, - key[memory.size(0) + right_context.size(0):], - value[memory.size(0) + right_context.size(0):], + key[right_context_end_idx:], + value[right_context_end_idx:], ) @@ -499,9 +542,7 @@ class EmformerLayer(nn.Module): self.use_memory = max_memory_size > 0 def _init_state( - self, - batch_size: int, - device: Optional[torch.device] + self, batch_size: int, device: Optional[torch.device] ) -> List[torch.Tensor]: """Initialize states with zeros.""" empty_memory = torch.zeros( @@ -519,8 +560,7 @@ class EmformerLayer(nn.Module): return [empty_memory, left_context_key, left_context_val, past_length] def _unpack_state( - self, - state: List[torch.Tensor] + self, state: List[torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Unpack cached states including: 1) output memory from previous chunks in the lower layer; @@ -532,11 +572,13 @@ class EmformerLayer(nn.Module): past_memory_length = min( self.max_memory_size, math.ceil(past_length / self.chunk_length) ) - pre_memory = state[0][self.max_memory_size - past_memory_length:] - left_context_key = \ - state[1][self.left_context_length - past_left_context_length:] - left_context_val = \ - state[2][self.left_context_length - past_left_context_length:] + memory_start_idx = self.max_memory_size - past_memory_length + pre_memory = state[0][memory_start_idx:] + left_context_start_idx = ( + self.left_context_length - past_left_context_length + ) + left_context_key = state[1][left_context_start_idx:] + left_context_val = state[2][left_context_start_idx:] return pre_memory, left_context_key, left_context_val def _pack_state( @@ -556,40 +598,46 @@ class EmformerLayer(nn.Module): new_memory = torch.cat([state[0], memory]) new_key = torch.cat([state[1], next_key]) new_val = 
torch.cat([state[2], next_val]) - state[0] = new_memory[new_memory.size(0) - self.max_memory_size:] - state[1] = new_key[new_key.size(0) - self.left_context_length:] - state[2] = new_val[new_val.size(0) - self.left_context_length:] + memory_start_idx = new_memory.size(0) - self.max_memory_size + state[0] = new_memory[memory_start_idx:] + key_start_idx = new_key.size(0) - self.left_context_length + state[1] = new_key[key_start_idx:] + val_start_idx = new_val.size(0) - self.left_context_length + state[2] = new_val[val_start_idx:] state[3] = state[3] + update_length return state def _apply_pre_attention_layer_norm( self, utterance: torch.Tensor, right_context: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - """Apply layer normalization before attention. """ + """Apply layer normalization before attention.""" layer_norm_input = self.layer_norm_input( torch.cat([right_context, utterance]) ) - layer_norm_utterance = layer_norm_input[right_context.size(0):] - layer_norm_right_context = layer_norm_input[:right_context.size(0)] + right_context_end_idx = right_context.size(0) + layer_norm_utterance = layer_norm_input[right_context_end_idx:] + layer_norm_right_context = layer_norm_input[:right_context_end_idx] return layer_norm_utterance, layer_norm_right_context def _apply_post_attention_ffn_layer_norm( self, output_right_context_utterance: torch.Tensor, utterance: torch.Tensor, - right_context: torch.Tensor + right_context: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """Apply feed forward and layer normalization after attention.""" # Apply residual connection between input and attention output. - result = self.dropout(output_right_context_utterance) + \ - torch.cat([right_context, utterance]) + result = self.dropout(output_right_context_utterance) + torch.cat( + [right_context, utterance] + ) # Apply feedforward module and residual connection. result = self.pos_ff(result) + result # Apply layer normalization for output. result = self.layer_norm_output(result) - output_utterance = result[right_context.size(0):] - output_right_context = result[:right_context.size(0)] + right_context_end_idx = right_context.size(0) + output_utterance = result[right_context_end_idx:] + output_right_context = result[:right_context_end_idx] return output_utterance, output_right_context def _apply_attention_forward( @@ -600,16 +648,16 @@ class EmformerLayer(nn.Module): memory: torch.Tensor, attention_mask: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: - """Apply attention in non-infer mode. """ + """Apply attention in non-infer mode.""" if attention_mask is None: raise ValueError( "attention_mask must be not None in non-infer mode. 
" ) if self.use_memory: - summary = self.summary_op( - utterance.permute(1, 2, 0) - ).permute(2, 0, 1) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + ) else: summary = torch.empty(0).to( dtype=utterance.dtype, device=utterance.device @@ -646,27 +694,32 @@ class EmformerLayer(nn.Module): """ if state is None: state = self._init_state(utterance.size(1), device=utterance.device) - pre_memory, left_context_key, left_context_val = \ - self._unpack_state(state) + pre_memory, left_context_key, left_context_val = self._unpack_state( + state + ) if self.use_memory: - summary = self.summary_op( - utterance.permute(1, 2, 0) - ).permute(2, 0, 1) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + ) summary = summary[:1] else: summary = torch.empty(0).to( dtype=utterance.dtype, device=utterance.device ) - output_right_context_utterance, output_memory, next_key, next_val = \ - self.attention.infer( - utterance=utterance, - lengths=lengths, - right_context=right_context, - summary=summary, - memory=pre_memory, - left_context_key=left_context_key, - left_context_val=left_context_val, - ) + ( + output_right_context_utterance, + output_memory, + next_key, + next_val, + ) = self.attention.infer( + utterance=utterance, + lengths=lengths, + right_context=right_context, + summary=summary, + memory=pre_memory, + left_context_key=left_context_key, + left_context_val=left_context_val, + ) state = self._pack_state( next_key, next_val, utterance.size(0), memory, state ) @@ -718,20 +771,22 @@ class EmformerLayer(nn.Module): layer_norm_utterance, layer_norm_right_context, ) = self._apply_pre_attention_layer_norm(utterance, right_context) - output_right_context_utterance, output_memory = \ - self._apply_attention_forward( - layer_norm_utterance, - lengths, - layer_norm_right_context, - memory, - attention_mask, - ) - output_utterance, output_right_context = \ - self._apply_post_attention_ffn_layer_norm( - output_right_context_utterance, - utterance, - right_context - ) + ( + output_right_context_utterance, + output_memory, + ) = self._apply_attention_forward( + layer_norm_utterance, + lengths, + layer_norm_right_context, + memory, + attention_mask, + ) + ( + output_utterance, + output_right_context, + ) = self._apply_post_attention_ffn_layer_norm( + output_right_context_utterance, utterance, right_context + ) return output_utterance, output_right_context, output_memory @torch.jit.export @@ -745,63 +800,66 @@ class EmformerLayer(nn.Module): ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]: """Forward pass for inference. - 1) Apply layer normalization on input utterance and right context - before attention; - 2) Apply attention module with cached state, compute updated utterance, - right context, and memory, and update state; - 3) Apply feed forward module and layer normalization on output utterance - and right context. + 1) Apply layer normalization on input utterance and right context + before attention; + 2) Apply attention module with cached state, compute updated utterance, + right context, and memory, and update state; + 3) Apply feed forward module and layer normalization on output + utterance and right context. - B: batch size; - D: embedding dimension; - R: length of right_context; - U: length of utterance; - M: length of memory. + B: batch size; + D: embedding dimension; + R: length of right_context; + U: length of utterance; + M: length of memory. - Args: - utterance (torch.Tensor): - Utterance frames, with shape (U, B, D). 
- lengths (torch.Tensor): - With shape (B,) and i-th element representing - number of valid frames for i-th batch element in utterance. - right_context (torch.Tensor): - Right context frames, with shape (R, B, D). - memory (torch.Tensor): - Memory elements, with shape (M, B, D). - state (List[torch.Tensor], optional): - List of tensors representing layer internal state generated in - preceding computation. (default=None) + Args: + utterance (torch.Tensor): + Utterance frames, with shape (U, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing + number of valid frames for i-th batch element in utterance. + right_context (torch.Tensor): + Right context frames, with shape (R, B, D). + memory (torch.Tensor): + Memory elements, with shape (M, B, D). + state (List[torch.Tensor], optional): + List of tensors representing layer internal state generated in + preceding computation. (default=None) - Returns: - (Tensor, Tensor, List[torch.Tensor], Tensor): - - output utterance, with shape (U, B, D); - - output right_context, with shape (R, B, D); - - output memory, with shape (1, B, D) or (0, B, D). - - output state. + Returns: + (Tensor, Tensor, List[torch.Tensor], Tensor): + - output utterance, with shape (U, B, D); + - output right_context, with shape (R, B, D); + - output memory, with shape (1, B, D) or (0, B, D). + - output state. """ ( layer_norm_utterance, layer_norm_right_context, ) = self._apply_pre_attention_layer_norm(utterance, right_context) - output_right_context_utterance, output_memory, output_state = \ - self._apply_attention_infer( - layer_norm_utterance, - lengths, - layer_norm_right_context, - memory, - state - ) - output_utterance, output_right_context = \ - self._apply_post_attention_ffn_layer_norm( - output_right_context_utterance, - utterance, - right_context - ) + ( + output_right_context_utterance, + output_memory, + output_state, + ) = self._apply_attention_infer( + layer_norm_utterance, + lengths, + layer_norm_right_context, + memory, + state, + ) + ( + output_utterance, + output_right_context, + ) = self._apply_post_attention_ffn_layer_norm( + output_right_context_utterance, utterance, right_context + ) return ( output_utterance, output_right_context, output_memory, - output_state + output_state, ) @@ -895,7 +953,7 @@ class EmformerEncoder(nn.Module): self.max_memory_size = max_memory_size def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: - """Hard copy each chunk's right context and concat them. 
""" + """Hard copy each chunk's right context and concat them.""" T = x.shape[0] num_segs = math.ceil( (T - self.right_context_length) / self.chunk_length @@ -905,7 +963,8 @@ class EmformerEncoder(nn.Module): start = (seg_idx + 1) * self.chunk_length end = start + self.right_context_length right_context_blocks.append(x[start:end]) - right_context_blocks.append(x[T - self.right_context_length:]) + last_right_context_start_idx = T - self.right_context_length + right_context_blocks.append(x[last_right_context_start_idx:]) return torch.cat(right_context_blocks) def _gen_attention_mask_col_widths( @@ -981,8 +1040,9 @@ class EmformerEncoder(nn.Module): num_cols = 9 # right context and utterance both attend to memory, right context, # utterance - right_context_utterance_cols_mask = \ - [idx in [1, 4, 7] for idx in range(num_cols)] + right_context_utterance_cols_mask = [ + idx in [1, 4, 7] for idx in range(num_cols) + ] # summary attends to right context, utterance summary_cols_mask = [idx in [4, 7] for idx in range(num_cols)] masks_to_concat = [right_context_mask, utterance_mask, summary_mask] @@ -990,8 +1050,9 @@ class EmformerEncoder(nn.Module): num_cols = 6 # right context and utterance both attend to right context and # utterance - right_context_utterance_cols_mask = \ - [idx in [1, 4] for idx in range(num_cols)] + right_context_utterance_cols_mask = [ + idx in [1, 4] for idx in range(num_cols) + ] summary_cols_mask = None masks_to_concat = [right_context_mask, utterance_mask] @@ -1002,7 +1063,7 @@ class EmformerEncoder(nn.Module): col_widths, right_context_utterance_cols_mask, self.right_context_length, - utterance.device + utterance.device, ) right_context_mask.append(right_context_mask_block) @@ -1053,13 +1114,13 @@ class EmformerEncoder(nn.Module): right_context at the end. """ right_context = self._gen_right_context(x) - utterance = x[:x.size(0) - self.right_context_length] + utterance = x[: x.size(0) - self.right_context_length] output_lengths = torch.clamp(lengths - self.right_context_length, min=0) attention_mask = self._gen_attention_mask(utterance) memory = ( - self.init_memory_op( - utterance.permute(1, 2, 0) - ).permute(2, 0, 1)[:-1] + self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1 + ] if self.use_memory else torch.empty(0).to(dtype=x.dtype, device=x.device) ) @@ -1159,13 +1220,9 @@ class Emformer(EncoderInterface): self.subsampling_factor = subsampling_factor self.right_context_length = right_context_length if subsampling_factor != 4: - raise NotImplementedError( - "Support only 'subsampling_factor=4'." - ) + raise NotImplementedError("Support only 'subsampling_factor=4'.") if chunk_length % 4 != 0: - raise NotImplementedError( - "chunk_length must be a mutiple of 4." - ) + raise NotImplementedError("chunk_length must be a mutiple of 4.") if left_context_length != 0 and left_context_length % 4 != 0: raise NotImplementedError( "left_context_length must be 0 or a mutiple of 4." @@ -1289,8 +1346,9 @@ class Emformer(EncoderInterface): x_lens = ((x_lens - 1) // 2 - 1) // 2 assert x.size(0) == x_lens.max().item() - output, output_lengths, output_states = \ - self.encoder.infer(x, x_lens, states) # (T, N, C) + output, output_lengths, output_states = self.encoder.infer( + x, x_lens, states + ) # (T, N, C) logits = self.encoder_output_layer(output) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)