update emformer.py

yaozengwei 2022-04-08 20:31:32 +08:00
parent 2d1b90f758
commit d58002c414
2 changed files with 233 additions and 174 deletions

View File

@ -4,6 +4,7 @@ repos:
hooks:
- id: black
args: [--line-length=80]
additional_dependencies: ['click==8.0.1']
- repo: https://github.com/PyCQA/flake8
rev: 3.9.2

View File

@ -1,3 +1,22 @@
# Copyright 2022 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# It is modified based on
# https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py.
import math
import warnings
from typing import List, Optional, Tuple
@ -22,29 +41,32 @@ def _get_activation_module(activation: str) -> nn.Module:
def _get_weight_init_gains(
weight_init_scale_strategy: Optional[str],
num_layers: int
weight_init_scale_strategy: Optional[str], num_layers: int
) -> List[Optional[float]]:
if weight_init_scale_strategy is None:
return [None for _ in range(num_layers)]
elif weight_init_scale_strategy == "depthwise":
return [1.0 / math.sqrt(layer_idx + 1)
for layer_idx in range(num_layers)]
return [
1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers)
]
elif weight_init_scale_strategy == "constant":
return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)]
else:
raise ValueError(f"Unsupported weight_init_scale_strategy value"
f"{weight_init_scale_strategy}")
raise ValueError(
f"Unsupported weight_init_scale_strategy value"
f"{weight_init_scale_strategy}"
)
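For reference, a quick check (not part of the commit) of the per-layer gains these strategies produce for a hypothetical 4-layer model:

import math

num_layers = 4
print([round(1.0 / math.sqrt(i + 1), 4) for i in range(num_layers)])
# "depthwise" strategy -> [1.0, 0.7071, 0.5774, 0.5]
print([round(1.0 / math.sqrt(2), 4) for _ in range(num_layers)])
# "constant" strategy  -> [0.7071, 0.7071, 0.7071, 0.7071]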
def _gen_attention_mask_block(
col_widths: List[int],
col_mask: List[bool],
num_rows: int,
device: torch.device
device: torch.device,
) -> torch.Tensor:
assert len(col_widths) == len(col_mask), (
"Length of col_widths must match that of col_mask")
assert len(col_widths) == len(
col_mask
), "Length of col_widths must match that of col_mask"
mask_block = [
torch.ones(num_rows, col_width, device=device)
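The hunk is cut off at this point; as a hedged sketch (mirroring the upstream torchaudio implementation this file is adapted from), the helper concatenates per-column ones/zeros blocks into one row block of the attention mask:

import torch

def gen_attention_mask_block(col_widths, col_mask, num_rows, device):
    # For each column group, emit a ones block if its flag is set and a
    # zeros block otherwise, then concatenate along the column dimension.
    assert len(col_widths) == len(col_mask)
    mask_block = [
        torch.ones(num_rows, col_width, device=device)
        if is_ones_col
        else torch.zeros(num_rows, col_width, device=device)
        for col_width, is_ones_col in zip(col_widths, col_mask)
    ]
    return torch.cat(mask_block, dim=1)

print(gen_attention_mask_block([2, 3], [True, False], 2, torch.device("cpu")))
# tensor([[1., 1., 0., 0., 0.],
#         [1., 1., 0., 0., 0.]])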
@ -99,9 +121,7 @@ class EmformerAttention(nn.Module):
self.scaling = (self.embed_dim // self.nhead) ** -0.5
self.emb_to_key_value = nn.Linear(
embed_dim, 2 * embed_dim, bias=True
)
self.emb_to_key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=True)
self.emb_to_query = nn.Linear(embed_dim, embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
@ -154,7 +174,7 @@ class EmformerAttention(nn.Module):
)
attention_weights_float = attention_weights_float.masked_fill(
padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
self.negative_inf
self.negative_inf,
)
attention_weights_float = attention_weights_float.view(
B * self.nhead, Q, -1
@ -164,9 +184,7 @@ class EmformerAttention(nn.Module):
attention_weights_float, dim=-1
).type_as(attention_weights)
attention_probs = nn.functional.dropout(
attention_probs,
p=float(self.dropout),
training=self.training
attention_probs, p=float(self.dropout), training=self.training
)
return attention_probs
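A small self-contained illustration of the padding-mask step above; the shapes and the use of float("-inf") are illustrative assumptions (the module uses its own negative_inf constant):

import torch
import torch.nn as nn

B, nhead, Q, KV = 2, 4, 3, 5
attention_weights = torch.randn(B, nhead, Q, KV)
# True marks padded key positions of each batch element.
padding_mask = torch.tensor(
    [[False, False, False, True, True],
     [False, False, False, False, True]]
)

masked = attention_weights.masked_fill(
    padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
)
attention_probs = nn.functional.softmax(
    masked.view(B * nhead, Q, KV), dim=-1
)
# Padded key positions now receive exactly zero attention probability.
assert torch.all(attention_probs[0, :, 3:] == 0)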
@ -242,14 +260,28 @@ class EmformerAttention(nn.Module):
# [mems, right context, left context, utterance]
M = memory.size(0)
R = right_context.size(0)
key = torch.cat([key[:M + R], left_context_key, key[M + R:]])
value = torch.cat([value[:M + R], left_context_val, value[M + R:]])
right_context_end_idx = M + R
key = torch.cat(
[
key[:right_context_end_idx],
left_context_key,
key[right_context_end_idx:],
]
)
value = torch.cat(
[
value[:right_context_end_idx],
left_context_val,
value[right_context_end_idx:],
]
)
# Compute attention weights from query, key, and value.
reshaped_query, reshaped_key, reshaped_value = [
tensor.contiguous().view(
-1, B * self.nhead, self.embed_dim // self.nhead
).transpose(0, 1) for tensor in [query, key, value]
tensor.contiguous()
.view(-1, B * self.nhead, self.embed_dim // self.nhead)
.transpose(0, 1)
for tensor in [query, key, value]
]
attention_weights = torch.bmm(
reshaped_query * self.scaling, reshaped_key.transpose(1, 2)
@ -272,18 +304,21 @@ class EmformerAttention(nn.Module):
attention = torch.bmm(attention_probs, reshaped_value)
Q = query.size(0)
assert attention.shape == (
B * self.nhead, Q, self.embed_dim // self.nhead,
B * self.nhead,
Q,
self.embed_dim // self.nhead,
)
attention = attention.transpose(0, 1).contiguous().view(
Q, B, self.embed_dim
attention = (
attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim)
)
# Apply output projection.
outputs = self.out_proj(attention)
S = summary.size(0)
output_right_context_utterance = outputs[:Q - S]
output_memory = outputs[Q - S:]
summary_start_idx = Q - S
output_right_context_utterance = outputs[:summary_start_idx]
output_memory = outputs[summary_start_idx:]
if self.tanh_on_mem:
output_memory = torch.tanh(output_memory)
else:
@ -331,14 +366,13 @@ class EmformerAttention(nn.Module):
- output of right context and utterance, with shape (R + U, B, D).
- memory output, with shape (M, B, D), where M = S - 1 or M = 0.
"""
output_right_context_utterance, output_memory, _, _ = \
self._forward_impl(
utterance,
lengths,
right_context,
summary,
memory,
attention_mask
(
output_right_context_utterance,
output_memory,
_,
_,
) = self._forward_impl(
utterance, lengths, right_context, summary, memory, attention_mask
)
return output_right_context_utterance, output_memory[:-1]
@ -394,15 +428,23 @@ class EmformerAttention(nn.Module):
# query: [right context, utterance, summary]
Q = right_context.size(0) + utterance.size(0) + summary.size(0)
# key, value: [memory, right context, left context, utterance]
KV = memory.size(0) + right_context.size(0) + \
left_context_key.size(0) + utterance.size(0)
attention_mask = torch.zeros(
Q, KV
).to(dtype=torch.bool, device=utterance.device)
KV = (
memory.size(0)
+ right_context.size(0)
+ left_context_key.size(0)
+ utterance.size(0)
)
attention_mask = torch.zeros(Q, KV).to(
dtype=torch.bool, device=utterance.device
)
# Disallow attention between the summary vector and the memory bank
attention_mask[-1, : memory.size(0)] = True
output_right_context_utterance, output_memory, key, value = \
self._forward_impl(
(
output_right_context_utterance,
output_memory,
key,
value,
) = self._forward_impl(
utterance,
lengths,
right_context,
@ -412,11 +454,12 @@ class EmformerAttention(nn.Module):
left_context_key=left_context_key,
left_context_val=left_context_val,
)
right_context_end_idx = memory.size(0) + right_context.size(0)
return (
output_right_context_utterance,
output_memory,
key[memory.size(0) + right_context.size(0):],
value[memory.size(0) + right_context.size(0):],
key[right_context_end_idx:],
value[right_context_end_idx:],
)
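To make the mask above concrete, a toy version with made-up sizes; the only disallowed entries are the summary row's memory columns:

import torch

M, R, L, U, S = 2, 1, 3, 4, 1   # memory, right context, left context, utterance, summary
Q = R + U + S                    # query: [right context, utterance, summary]
KV = M + R + L + U               # key/value: [memory, right context, left context, utterance]
attention_mask = torch.zeros(Q, KV, dtype=torch.bool)
attention_mask[-1, :M] = True    # summary must not attend to the memory bank
assert attention_mask.sum() == M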
@ -499,9 +542,7 @@ class EmformerLayer(nn.Module):
self.use_memory = max_memory_size > 0
def _init_state(
self,
batch_size: int,
device: Optional[torch.device]
self, batch_size: int, device: Optional[torch.device]
) -> List[torch.Tensor]:
"""Initialize states with zeros."""
empty_memory = torch.zeros(
@ -519,8 +560,7 @@ class EmformerLayer(nn.Module):
return [empty_memory, left_context_key, left_context_val, past_length]
def _unpack_state(
self,
state: List[torch.Tensor]
self, state: List[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Unpack cached states including:
1) output memory from previous chunks in the lower layer;
@ -532,11 +572,13 @@ class EmformerLayer(nn.Module):
past_memory_length = min(
self.max_memory_size, math.ceil(past_length / self.chunk_length)
)
pre_memory = state[0][self.max_memory_size - past_memory_length:]
left_context_key = \
state[1][self.left_context_length - past_left_context_length:]
left_context_val = \
state[2][self.left_context_length - past_left_context_length:]
memory_start_idx = self.max_memory_size - past_memory_length
pre_memory = state[0][memory_start_idx:]
left_context_start_idx = (
self.left_context_length - past_left_context_length
)
left_context_key = state[1][left_context_start_idx:]
left_context_val = state[2][left_context_start_idx:]
return pre_memory, left_context_key, left_context_val
def _pack_state(
@ -556,9 +598,12 @@ class EmformerLayer(nn.Module):
new_memory = torch.cat([state[0], memory])
new_key = torch.cat([state[1], next_key])
new_val = torch.cat([state[2], next_val])
state[0] = new_memory[new_memory.size(0) - self.max_memory_size:]
state[1] = new_key[new_key.size(0) - self.left_context_length:]
state[2] = new_val[new_val.size(0) - self.left_context_length:]
memory_start_idx = new_memory.size(0) - self.max_memory_size
state[0] = new_memory[memory_start_idx:]
key_start_idx = new_key.size(0) - self.left_context_length
state[1] = new_key[key_start_idx:]
val_start_idx = new_val.size(0) - self.left_context_length
state[2] = new_val[val_start_idx:]
state[3] = state[3] + update_length
return state
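A toy illustration (with assumed sizes) of the fixed-size trimming in _pack_state: the newest frames are appended, and the buffer keeps only the most recent left_context_length entries:

import torch

left_context_length, B, D = 4, 1, 8
cached_key = torch.zeros(left_context_length, B, D)   # state[1]
next_key = torch.randn(3, B, D)                       # key frames from this chunk

new_key = torch.cat([cached_key, next_key])
key_start_idx = new_key.size(0) - left_context_length
cached_key = new_key[key_start_idx:]                  # keep the newest 4 frames
assert cached_key.shape == (left_context_length, B, D)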
@ -569,27 +614,30 @@ class EmformerLayer(nn.Module):
layer_norm_input = self.layer_norm_input(
torch.cat([right_context, utterance])
)
layer_norm_utterance = layer_norm_input[right_context.size(0):]
layer_norm_right_context = layer_norm_input[:right_context.size(0)]
right_context_end_idx = right_context.size(0)
layer_norm_utterance = layer_norm_input[right_context_end_idx:]
layer_norm_right_context = layer_norm_input[:right_context_end_idx]
return layer_norm_utterance, layer_norm_right_context
def _apply_post_attention_ffn_layer_norm(
self,
output_right_context_utterance: torch.Tensor,
utterance: torch.Tensor,
right_context: torch.Tensor
right_context: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply feed forward and layer normalization after attention."""
# Apply residual connection between input and attention output.
result = self.dropout(output_right_context_utterance) + \
torch.cat([right_context, utterance])
result = self.dropout(output_right_context_utterance) + torch.cat(
[right_context, utterance]
)
# Apply feedforward module and residual connection.
result = self.pos_ff(result) + result
# Apply layer normalization for output.
result = self.layer_norm_output(result)
output_utterance = result[right_context.size(0):]
output_right_context = result[:right_context.size(0)]
right_context_end_idx = right_context.size(0)
output_utterance = result[right_context_end_idx:]
output_right_context = result[:right_context_end_idx]
return output_utterance, output_right_context
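The same residual / feed-forward / layer-norm pattern in isolation; the pos_ff module below is a generic stand-in, not the one this file defines:

import torch
import torch.nn as nn

R, U, B, D = 2, 6, 3, 8
pos_ff = nn.Sequential(nn.Linear(D, 4 * D), nn.ReLU(), nn.Linear(4 * D, D))
layer_norm_output = nn.LayerNorm(D)
dropout = nn.Dropout(p=0.1)

right_context = torch.randn(R, B, D)
utterance = torch.randn(U, B, D)
output_right_context_utterance = torch.randn(R + U, B, D)  # attention output

# Residual around the attention output, then FFN with residual, then LN.
result = dropout(output_right_context_utterance) + torch.cat(
    [right_context, utterance]
)
result = pos_ff(result) + result
result = layer_norm_output(result)
output_utterance = result[R:]          # (U, B, D)
output_right_context = result[:R]      # (R, B, D)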
def _apply_attention_forward(
@ -607,9 +655,9 @@ class EmformerLayer(nn.Module):
)
if self.use_memory:
summary = self.summary_op(
utterance.permute(1, 2, 0)
).permute(2, 0, 1)
summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
2, 0, 1
)
else:
summary = torch.empty(0).to(
dtype=utterance.dtype, device=utterance.device
@ -646,19 +694,24 @@ class EmformerLayer(nn.Module):
"""
if state is None:
state = self._init_state(utterance.size(1), device=utterance.device)
pre_memory, left_context_key, left_context_val = \
self._unpack_state(state)
pre_memory, left_context_key, left_context_val = self._unpack_state(
state
)
if self.use_memory:
summary = self.summary_op(
utterance.permute(1, 2, 0)
).permute(2, 0, 1)
summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
2, 0, 1
)
summary = summary[:1]
else:
summary = torch.empty(0).to(
dtype=utterance.dtype, device=utterance.device
)
output_right_context_utterance, output_memory, next_key, next_val = \
self.attention.infer(
(
output_right_context_utterance,
output_memory,
next_key,
next_val,
) = self.attention.infer(
utterance=utterance,
lengths=lengths,
right_context=right_context,
@ -718,19 +771,21 @@ class EmformerLayer(nn.Module):
layer_norm_utterance,
layer_norm_right_context,
) = self._apply_pre_attention_layer_norm(utterance, right_context)
output_right_context_utterance, output_memory = \
self._apply_attention_forward(
(
output_right_context_utterance,
output_memory,
) = self._apply_attention_forward(
layer_norm_utterance,
lengths,
layer_norm_right_context,
memory,
attention_mask,
)
output_utterance, output_right_context = \
self._apply_post_attention_ffn_layer_norm(
output_right_context_utterance,
utterance,
right_context
(
output_utterance,
output_right_context,
) = self._apply_post_attention_ffn_layer_norm(
output_right_context_utterance, utterance, right_context
)
return output_utterance, output_right_context, output_memory
@ -749,8 +804,8 @@ class EmformerLayer(nn.Module):
before attention;
2) Apply attention module with cached state, compute updated utterance,
right context, and memory, and update state;
3) Apply feed forward module and layer normalization on output utterance
and right context.
3) Apply feed forward module and layer normalization on output
utterance and right context.
B: batch size;
D: embedding dimension;
@ -783,25 +838,28 @@ class EmformerLayer(nn.Module):
layer_norm_utterance,
layer_norm_right_context,
) = self._apply_pre_attention_layer_norm(utterance, right_context)
output_right_context_utterance, output_memory, output_state = \
self._apply_attention_infer(
(
output_right_context_utterance,
output_memory,
output_state,
) = self._apply_attention_infer(
layer_norm_utterance,
lengths,
layer_norm_right_context,
memory,
state
state,
)
output_utterance, output_right_context = \
self._apply_post_attention_ffn_layer_norm(
output_right_context_utterance,
utterance,
right_context
(
output_utterance,
output_right_context,
) = self._apply_post_attention_ffn_layer_norm(
output_right_context_utterance, utterance, right_context
)
return (
output_utterance,
output_right_context,
output_memory,
output_state
output_state,
)
@ -905,7 +963,8 @@ class EmformerEncoder(nn.Module):
start = (seg_idx + 1) * self.chunk_length
end = start + self.right_context_length
right_context_blocks.append(x[start:end])
right_context_blocks.append(x[T - self.right_context_length:])
last_right_context_start_idx = T - self.right_context_length
right_context_blocks.append(x[last_right_context_start_idx:])
return torch.cat(right_context_blocks)
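A worked example (with made-up lengths) of the right-context hard copy implemented above: for every chunk, the frames that follow it are copied out, and the final block comes from the tail of the input:

import math
import torch

chunk_length, right_context_length = 4, 2
T, B, D = 14, 1, 1
x = torch.arange(T, dtype=torch.float32).view(T, B, D)

num_segs = math.ceil((T - right_context_length) / chunk_length)  # 3 chunks
right_context_blocks = []
for seg_idx in range(num_segs - 1):
    start = (seg_idx + 1) * chunk_length
    end = start + right_context_length
    right_context_blocks.append(x[start:end])
last_right_context_start_idx = T - right_context_length
right_context_blocks.append(x[last_right_context_start_idx:])
right_context = torch.cat(right_context_blocks)
assert right_context.shape == (num_segs * right_context_length, B, D)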
def _gen_attention_mask_col_widths(
@ -981,8 +1040,9 @@ class EmformerEncoder(nn.Module):
num_cols = 9
# right context and utterance both attend to memory, right context,
# utterance
right_context_utterance_cols_mask = \
[idx in [1, 4, 7] for idx in range(num_cols)]
right_context_utterance_cols_mask = [
idx in [1, 4, 7] for idx in range(num_cols)
]
# summary attends to right context, utterance
summary_cols_mask = [idx in [4, 7] for idx in range(num_cols)]
masks_to_concat = [right_context_mask, utterance_mask, summary_mask]
@ -990,8 +1050,9 @@ class EmformerEncoder(nn.Module):
num_cols = 6
# right context and utterance both attend to right context and
# utterance
right_context_utterance_cols_mask = \
[idx in [1, 4] for idx in range(num_cols)]
right_context_utterance_cols_mask = [
idx in [1, 4] for idx in range(num_cols)
]
summary_cols_mask = None
masks_to_concat = [right_context_mask, utterance_mask]
@ -1002,7 +1063,7 @@ class EmformerEncoder(nn.Module):
col_widths,
right_context_utterance_cols_mask,
self.right_context_length,
utterance.device
utterance.device,
)
right_context_mask.append(right_context_mask_block)
@ -1057,9 +1118,9 @@ class EmformerEncoder(nn.Module):
output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
attention_mask = self._gen_attention_mask(utterance)
memory = (
self.init_memory_op(
utterance.permute(1, 2, 0)
).permute(2, 0, 1)[:-1]
self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
:-1
]
if self.use_memory
else torch.empty(0).to(dtype=x.dtype, device=x.device)
)
@ -1159,13 +1220,9 @@ class Emformer(EncoderInterface):
self.subsampling_factor = subsampling_factor
self.right_context_length = right_context_length
if subsampling_factor != 4:
raise NotImplementedError(
"Support only 'subsampling_factor=4'."
)
raise NotImplementedError("Support only 'subsampling_factor=4'.")
if chunk_length % 4 != 0:
raise NotImplementedError(
"chunk_length must be a mutiple of 4."
)
raise NotImplementedError("chunk_length must be a mutiple of 4.")
if left_context_length != 0 and left_context_length % 4 != 0:
raise NotImplementedError(
"left_context_length must be 0 or a mutiple of 4."
@ -1289,8 +1346,9 @@ class Emformer(EncoderInterface):
x_lens = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == x_lens.max().item()
output, output_lengths, output_states = \
self.encoder.infer(x, x_lens, states) # (T, N, C)
output, output_lengths, output_states = self.encoder.infer(
x, x_lens, states
) # (T, N, C)
logits = self.encoder_output_layer(output)
logits = logits.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
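As a worked example of the length arithmetic above (assuming the usual two stride-2 subsampling convolutions behind subsampling_factor=4):

import torch

x_lens = torch.tensor([100, 17, 8])
out_lens = ((x_lens - 1) // 2 - 1) // 2
print(out_lens)  # tensor([24,  3,  1])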