Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-04 06:34:20 +00:00

Commit d58002c414: update emformer.py
Parent commit: 2d1b90f758
@@ -4,6 +4,7 @@ repos:
     hooks:
       - id: black
         args: [--line-length=80]
+        additional_dependencies: ['click==8.0.1']

   - repo: https://github.com/PyCQA/flake8
     rev: 3.9.2
@@ -1,3 +1,22 @@
+# Copyright 2022 Xiaomi Corporation (Author: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# It is modified based on
+# https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py.
+
 import math
 import warnings
 from typing import List, Optional, Tuple
@@ -22,29 +41,32 @@ def _get_activation_module(activation: str) -> nn.Module:


 def _get_weight_init_gains(
-    weight_init_scale_strategy: Optional[str],
-    num_layers: int
+    weight_init_scale_strategy: Optional[str], num_layers: int
 ) -> List[Optional[float]]:
     if weight_init_scale_strategy is None:
         return [None for _ in range(num_layers)]
     elif weight_init_scale_strategy == "depthwise":
-        return [1.0 / math.sqrt(layer_idx + 1)
-                for layer_idx in range(num_layers)]
+        return [
+            1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers)
+        ]
     elif weight_init_scale_strategy == "constant":
         return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)]
     else:
-        raise ValueError(f"Unsupported weight_init_scale_strategy value"
-                         f"{weight_init_scale_strategy}")
+        raise ValueError(
+            f"Unsupported weight_init_scale_strategy value"
+            f"{weight_init_scale_strategy}"
+        )


 def _gen_attention_mask_block(
     col_widths: List[int],
     col_mask: List[bool],
     num_rows: int,
-    device: torch.device
+    device: torch.device,
 ) -> torch.Tensor:
-    assert len(col_widths) == len(col_mask), (
-        "Length of col_widths must match that of col_mask")
+    assert len(col_widths) == len(
+        col_mask
+    ), "Length of col_widths must match that of col_mask"

     mask_block = [
         torch.ones(num_rows, col_width, device=device)
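For reference, a minimal sketch (not part of the commit) of the gain schedule that _get_weight_init_gains produces; the 4-layer count is an assumption chosen for illustration:

import math

num_layers = 4  # hypothetical layer count
# "depthwise": gain shrinks with depth, 1 / sqrt(layer_idx + 1)
depthwise = [1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers)]
# "constant": the same 1 / sqrt(2) gain for every layer
constant = [1.0 / math.sqrt(2) for _ in range(num_layers)]
print(depthwise)  # [1.0, 0.7071..., 0.5773..., 0.5]
print(constant)   # [0.7071..., 0.7071..., 0.7071..., 0.7071...]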
@@ -99,9 +121,7 @@ class EmformerAttention(nn.Module):

         self.scaling = (self.embed_dim // self.nhead) ** -0.5

-        self.emb_to_key_value = nn.Linear(
-            embed_dim, 2 * embed_dim, bias=True
-        )
+        self.emb_to_key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=True)
         self.emb_to_query = nn.Linear(embed_dim, embed_dim, bias=True)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

@@ -154,7 +174,7 @@ class EmformerAttention(nn.Module):
             )
             attention_weights_float = attention_weights_float.masked_fill(
                 padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
-                self.negative_inf
+                self.negative_inf,
             )
         attention_weights_float = attention_weights_float.view(
             B * self.nhead, Q, -1
@@ -164,9 +184,7 @@ class EmformerAttention(nn.Module):
             attention_weights_float, dim=-1
         ).type_as(attention_weights)
         attention_probs = nn.functional.dropout(
-            attention_probs,
-            p=float(self.dropout),
-            training=self.training
+            attention_probs, p=float(self.dropout), training=self.training
         )
         return attention_probs

@@ -242,14 +260,28 @@ class EmformerAttention(nn.Module):
             # [mems, right context, left context, utterance]
             M = memory.size(0)
             R = right_context.size(0)
-            key = torch.cat([key[:M + R], left_context_key, key[M + R:]])
-            value = torch.cat([value[:M + R], left_context_val, value[M + R:]])
+            right_context_end_idx = M + R
+            key = torch.cat(
+                [
+                    key[:right_context_end_idx],
+                    left_context_key,
+                    key[right_context_end_idx:],
+                ]
+            )
+            value = torch.cat(
+                [
+                    value[:right_context_end_idx],
+                    left_context_val,
+                    value[right_context_end_idx:],
+                ]
+            )

         # Compute attention weights from query, key, and value.
         reshaped_query, reshaped_key, reshaped_value = [
-            tensor.contiguous().view(
-                -1, B * self.nhead, self.embed_dim // self.nhead
-            ).transpose(0, 1) for tensor in [query, key, value]
+            tensor.contiguous()
+            .view(-1, B * self.nhead, self.embed_dim // self.nhead)
+            .transpose(0, 1)
+            for tensor in [query, key, value]
         ]
         attention_weights = torch.bmm(
             reshaped_query * self.scaling, reshaped_key.transpose(1, 2)
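The rewritten torch.cat above keeps the inference-time splice order unchanged: cached left-context rows go between the [memory, right context] prefix and the current utterance rows. A toy sketch of that splice (the sizes M, R, L, U and the feature width are assumptions):

import torch

M, R, L, U = 2, 1, 3, 4  # memory, right context, left context, utterance rows
key = torch.arange(float(M + R + U)).unsqueeze(1)                # rows without left context
left_context_key = 100.0 + torch.arange(float(L)).unsqueeze(1)   # cached rows from earlier chunks
right_context_end_idx = M + R
key = torch.cat(
    [
        key[:right_context_end_idx],   # memory + right context rows stay first
        left_context_key,              # cached left-context rows spliced in
        key[right_context_end_idx:],   # current utterance rows follow
    ]
)
print(key.squeeze(1))  # order: memory, right context, left context, utterance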
@@ -272,18 +304,21 @@ class EmformerAttention(nn.Module):
         attention = torch.bmm(attention_probs, reshaped_value)
         Q = query.size(0)
         assert attention.shape == (
-            B * self.nhead, Q, self.embed_dim // self.nhead,
+            B * self.nhead,
+            Q,
+            self.embed_dim // self.nhead,
         )
-        attention = attention.transpose(0, 1).contiguous().view(
-            Q, B, self.embed_dim
+        attention = (
+            attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim)
         )

         # Apply output projection.
         outputs = self.out_proj(attention)

         S = summary.size(0)
-        output_right_context_utterance = outputs[:Q - S]
-        output_memory = outputs[Q - S:]
+        summary_start_idx = Q - S
+        output_right_context_utterance = outputs[:summary_start_idx]
+        output_memory = outputs[summary_start_idx:]
         if self.tanh_on_mem:
             output_memory = torch.tanh(output_memory)
         else:
@@ -331,14 +366,13 @@ class EmformerAttention(nn.Module):
             - output of right context and utterance, with shape (R + U, B, D).
             - memory output, with shape (M, B, D), where M = S - 1 or M = 0.
         """
-        output_right_context_utterance, output_memory, _, _ = \
-            self._forward_impl(
-                utterance,
-                lengths,
-                right_context,
-                summary,
-                memory,
-                attention_mask
+        (
+            output_right_context_utterance,
+            output_memory,
+            _,
+            _,
+        ) = self._forward_impl(
+            utterance, lengths, right_context, summary, memory, attention_mask
         )
         return output_right_context_utterance, output_memory[:-1]

@@ -394,15 +428,23 @@ class EmformerAttention(nn.Module):
         # query: [right context, utterance, summary]
         Q = right_context.size(0) + utterance.size(0) + summary.size(0)
         # key, value: [memory, right context, left context, utterance]
-        KV = memory.size(0) + right_context.size(0) + \
-            left_context_key.size(0) + utterance.size(0)
-        attention_mask = torch.zeros(
-            Q, KV
-        ).to(dtype=torch.bool, device=utterance.device)
+        KV = (
+            memory.size(0)
+            + right_context.size(0)
+            + left_context_key.size(0)
+            + utterance.size(0)
+        )
+        attention_mask = torch.zeros(Q, KV).to(
+            dtype=torch.bool, device=utterance.device
+        )
         # Disallow attention between the summary vector and the memory bank
         attention_mask[-1, : memory.size(0)] = True
-        output_right_context_utterance, output_memory, key, value = \
-            self._forward_impl(
+        (
+            output_right_context_utterance,
+            output_memory,
+            key,
+            value,
+        ) = self._forward_impl(
             utterance,
             lengths,
             right_context,
@@ -412,11 +454,12 @@ class EmformerAttention(nn.Module):
             left_context_key=left_context_key,
             left_context_val=left_context_val,
         )
+        right_context_end_idx = memory.size(0) + right_context.size(0)
         return (
             output_right_context_utterance,
             output_memory,
-            key[memory.size(0) + right_context.size(0):],
-            value[memory.size(0) + right_context.size(0):],
+            key[right_context_end_idx:],
+            value[right_context_end_idx:],
         )


@@ -499,9 +542,7 @@ class EmformerLayer(nn.Module):
         self.use_memory = max_memory_size > 0

     def _init_state(
-        self,
-        batch_size: int,
-        device: Optional[torch.device]
+        self, batch_size: int, device: Optional[torch.device]
     ) -> List[torch.Tensor]:
         """Initialize states with zeros."""
         empty_memory = torch.zeros(
@@ -519,8 +560,7 @@ class EmformerLayer(nn.Module):
         return [empty_memory, left_context_key, left_context_val, past_length]

     def _unpack_state(
-        self,
-        state: List[torch.Tensor]
+        self, state: List[torch.Tensor]
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Unpack cached states including:
         1) output memory from previous chunks in the lower layer;
@@ -532,11 +572,13 @@ class EmformerLayer(nn.Module):
         past_memory_length = min(
             self.max_memory_size, math.ceil(past_length / self.chunk_length)
         )
-        pre_memory = state[0][self.max_memory_size - past_memory_length:]
-        left_context_key = \
-            state[1][self.left_context_length - past_left_context_length:]
-        left_context_val = \
-            state[2][self.left_context_length - past_left_context_length:]
+        memory_start_idx = self.max_memory_size - past_memory_length
+        pre_memory = state[0][memory_start_idx:]
+        left_context_start_idx = (
+            self.left_context_length - past_left_context_length
+        )
+        left_context_key = state[1][left_context_start_idx:]
+        left_context_val = state[2][left_context_start_idx:]
         return pre_memory, left_context_key, left_context_val

     def _pack_state(
@@ -556,9 +598,12 @@ class EmformerLayer(nn.Module):
         new_memory = torch.cat([state[0], memory])
         new_key = torch.cat([state[1], next_key])
         new_val = torch.cat([state[2], next_val])
-        state[0] = new_memory[new_memory.size(0) - self.max_memory_size:]
-        state[1] = new_key[new_key.size(0) - self.left_context_length:]
-        state[2] = new_val[new_val.size(0) - self.left_context_length:]
+        memory_start_idx = new_memory.size(0) - self.max_memory_size
+        state[0] = new_memory[memory_start_idx:]
+        key_start_idx = new_key.size(0) - self.left_context_length
+        state[1] = new_key[key_start_idx:]
+        val_start_idx = new_val.size(0) - self.left_context_length
+        state[2] = new_val[val_start_idx:]
         state[3] = state[3] + update_length
         return state

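_pack_state above now just names the slice offsets; the caching behaviour is the same: append the rows produced by the current chunk, then keep only the newest max_memory_size (or left_context_length) rows. A small sketch of that fixed-size cache update, with assumed sizes:

import torch

max_memory_size = 4
cached_memory = torch.zeros(max_memory_size, 8)   # state[0], initialized with zeros
new_rows = torch.randn(2, 8)                      # memory emitted by the current chunk
new_memory = torch.cat([cached_memory, new_rows])
memory_start_idx = new_memory.size(0) - max_memory_size
cached_memory = new_memory[memory_start_idx:]     # keep only the newest rows
assert cached_memory.size(0) == max_memory_size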
@@ -569,27 +614,30 @@ class EmformerLayer(nn.Module):
         layer_norm_input = self.layer_norm_input(
             torch.cat([right_context, utterance])
         )
-        layer_norm_utterance = layer_norm_input[right_context.size(0):]
-        layer_norm_right_context = layer_norm_input[:right_context.size(0)]
+        right_context_end_idx = right_context.size(0)
+        layer_norm_utterance = layer_norm_input[right_context_end_idx:]
+        layer_norm_right_context = layer_norm_input[:right_context_end_idx]
         return layer_norm_utterance, layer_norm_right_context

     def _apply_post_attention_ffn_layer_norm(
         self,
         output_right_context_utterance: torch.Tensor,
         utterance: torch.Tensor,
-        right_context: torch.Tensor
+        right_context: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Apply feed forward and layer normalization after attention."""
         # Apply residual connection between input and attention output.
-        result = self.dropout(output_right_context_utterance) + \
-            torch.cat([right_context, utterance])
+        result = self.dropout(output_right_context_utterance) + torch.cat(
+            [right_context, utterance]
+        )
         # Apply feedforward module and residual connection.
         result = self.pos_ff(result) + result
         # Apply layer normalization for output.
         result = self.layer_norm_output(result)

-        output_utterance = result[right_context.size(0):]
-        output_right_context = result[:right_context.size(0)]
+        right_context_end_idx = right_context.size(0)
+        output_utterance = result[right_context_end_idx:]
+        output_right_context = result[:right_context_end_idx]
         return output_utterance, output_right_context

     def _apply_attention_forward(
@@ -607,9 +655,9 @@ class EmformerLayer(nn.Module):
             )

         if self.use_memory:
-            summary = self.summary_op(
-                utterance.permute(1, 2, 0)
-            ).permute(2, 0, 1)
+            summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
+                2, 0, 1
+            )
         else:
             summary = torch.empty(0).to(
                 dtype=utterance.dtype, device=utterance.device
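summary_op above is wrapped in a permute round trip because pooling modules expect (B, D, T) while the encoder keeps tensors as (T, B, D). A minimal sketch of the same trick, assuming summary_op is an average pool over whole chunks (the kernel size and shapes here are assumptions):

import torch
import torch.nn as nn

T, B, D, chunk_length = 8, 2, 16, 4
utterance = torch.randn(T, B, D)  # (T, B, D) layout used by the encoder
summary_op = nn.AvgPool1d(
    kernel_size=chunk_length, stride=chunk_length, ceil_mode=True
)
# (T, B, D) -> (B, D, T) -> pooled over time -> back to (T', B, D)
summary = summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)
print(summary.shape)  # torch.Size([2, 2, 16]): one averaged vector per chunk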
@@ -646,19 +694,24 @@ class EmformerLayer(nn.Module):
         """
         if state is None:
             state = self._init_state(utterance.size(1), device=utterance.device)
-        pre_memory, left_context_key, left_context_val = \
-            self._unpack_state(state)
+        pre_memory, left_context_key, left_context_val = self._unpack_state(
+            state
+        )
         if self.use_memory:
-            summary = self.summary_op(
-                utterance.permute(1, 2, 0)
-            ).permute(2, 0, 1)
+            summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
+                2, 0, 1
+            )
             summary = summary[:1]
         else:
             summary = torch.empty(0).to(
                 dtype=utterance.dtype, device=utterance.device
             )
-        output_right_context_utterance, output_memory, next_key, next_val = \
-            self.attention.infer(
+        (
+            output_right_context_utterance,
+            output_memory,
+            next_key,
+            next_val,
+        ) = self.attention.infer(
             utterance=utterance,
             lengths=lengths,
             right_context=right_context,
@@ -718,19 +771,21 @@ class EmformerLayer(nn.Module):
             layer_norm_utterance,
             layer_norm_right_context,
         ) = self._apply_pre_attention_layer_norm(utterance, right_context)
-        output_right_context_utterance, output_memory = \
-            self._apply_attention_forward(
+        (
+            output_right_context_utterance,
+            output_memory,
+        ) = self._apply_attention_forward(
             layer_norm_utterance,
             lengths,
             layer_norm_right_context,
             memory,
             attention_mask,
         )
-        output_utterance, output_right_context = \
-            self._apply_post_attention_ffn_layer_norm(
-                output_right_context_utterance,
-                utterance,
-                right_context
+        (
+            output_utterance,
+            output_right_context,
+        ) = self._apply_post_attention_ffn_layer_norm(
+            output_right_context_utterance, utterance, right_context
         )
         return output_utterance, output_right_context, output_memory

@@ -749,8 +804,8 @@ class EmformerLayer(nn.Module):
            before attention;
         2) Apply attention module with cached state, compute updated utterance,
            right context, and memory, and update state;
-        3) Apply feed forward module and layer normalization on output utterance
-           and right context.
+        3) Apply feed forward module and layer normalization on output
+           utterance and right context.

         B: batch size;
         D: embedding dimension;
@@ -783,25 +838,28 @@ class EmformerLayer(nn.Module):
             layer_norm_utterance,
             layer_norm_right_context,
         ) = self._apply_pre_attention_layer_norm(utterance, right_context)
-        output_right_context_utterance, output_memory, output_state = \
-            self._apply_attention_infer(
+        (
+            output_right_context_utterance,
+            output_memory,
+            output_state,
+        ) = self._apply_attention_infer(
             layer_norm_utterance,
             lengths,
             layer_norm_right_context,
             memory,
-            state
+            state,
         )
-        output_utterance, output_right_context = \
-            self._apply_post_attention_ffn_layer_norm(
-                output_right_context_utterance,
-                utterance,
-                right_context
+        (
+            output_utterance,
+            output_right_context,
+        ) = self._apply_post_attention_ffn_layer_norm(
+            output_right_context_utterance, utterance, right_context
         )
         return (
             output_utterance,
             output_right_context,
             output_memory,
-            output_state
+            output_state,
         )


@@ -905,7 +963,8 @@ class EmformerEncoder(nn.Module):
             start = (seg_idx + 1) * self.chunk_length
             end = start + self.right_context_length
             right_context_blocks.append(x[start:end])
-        right_context_blocks.append(x[T - self.right_context_length:])
+        last_right_context_start_idx = T - self.right_context_length
+        right_context_blocks.append(x[last_right_context_start_idx:])
         return torch.cat(right_context_blocks)

     def _gen_attention_mask_col_widths(
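The hunk above only names the start index of the final block; the gathering logic is unchanged: each chunk contributes the right_context_length frames that follow it, and the last block is taken from the tail of the input. A toy walk-through (chunk sizes and T are assumptions):

import torch

chunk_length, right_context_length = 4, 2
T = 12
x = torch.arange(float(T)).reshape(T, 1, 1)  # (T, B=1, D=1), frame index as value

right_context_blocks = []
num_segs = T // chunk_length  # 3 chunks in this toy setup
for seg_idx in range(num_segs - 1):
    start = (seg_idx + 1) * chunk_length
    end = start + right_context_length
    right_context_blocks.append(x[start:end])
last_right_context_start_idx = T - right_context_length
right_context_blocks.append(x[last_right_context_start_idx:])
print(torch.cat(right_context_blocks).flatten())  # frames 4, 5, 8, 9, 10, 11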
@@ -981,8 +1040,9 @@ class EmformerEncoder(nn.Module):
             num_cols = 9
             # right context and utterance both attend to memory, right context,
             # utterance
-            right_context_utterance_cols_mask = \
-                [idx in [1, 4, 7] for idx in range(num_cols)]
+            right_context_utterance_cols_mask = [
+                idx in [1, 4, 7] for idx in range(num_cols)
+            ]
             # summary attends to right context, utterance
             summary_cols_mask = [idx in [4, 7] for idx in range(num_cols)]
             masks_to_concat = [right_context_mask, utterance_mask, summary_mask]
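For reference, each mask row built here is assembled column by column with _gen_attention_mask_block (shown earlier in this diff): allowed columns become blocks of ones, everything else zeros. A hand-worked sketch with three assumed column widths:

import torch

col_widths = [2, 3, 1]
col_mask = [True, False, True]  # attend to the first and last column groups only
num_rows = 2
mask_block = [
    torch.ones(num_rows, col_width)
    if is_ones_col
    else torch.zeros(num_rows, col_width)
    for col_width, is_ones_col in zip(col_widths, col_mask)
]
print(torch.cat(mask_block, dim=1))
# tensor([[1., 1., 0., 0., 0., 1.],
#         [1., 1., 0., 0., 0., 1.]])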
@@ -990,8 +1050,9 @@ class EmformerEncoder(nn.Module):
             num_cols = 6
             # right context and utterance both attend to right context and
             # utterance
-            right_context_utterance_cols_mask = \
-                [idx in [1, 4] for idx in range(num_cols)]
+            right_context_utterance_cols_mask = [
+                idx in [1, 4] for idx in range(num_cols)
+            ]
             summary_cols_mask = None
             masks_to_concat = [right_context_mask, utterance_mask]

@@ -1002,7 +1063,7 @@ class EmformerEncoder(nn.Module):
                 col_widths,
                 right_context_utterance_cols_mask,
                 self.right_context_length,
-                utterance.device
+                utterance.device,
             )
             right_context_mask.append(right_context_mask_block)

@@ -1057,9 +1118,9 @@ class EmformerEncoder(nn.Module):
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
         attention_mask = self._gen_attention_mask(utterance)
         memory = (
-            self.init_memory_op(
-                utterance.permute(1, 2, 0)
-            ).permute(2, 0, 1)[:-1]
+            self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
+                :-1
+            ]
             if self.use_memory
             else torch.empty(0).to(dtype=x.dtype, device=x.device)
         )
@@ -1159,13 +1220,9 @@ class Emformer(EncoderInterface):
         self.subsampling_factor = subsampling_factor
         self.right_context_length = right_context_length
         if subsampling_factor != 4:
-            raise NotImplementedError(
-                "Support only 'subsampling_factor=4'."
-            )
+            raise NotImplementedError("Support only 'subsampling_factor=4'.")
         if chunk_length % 4 != 0:
-            raise NotImplementedError(
-                "chunk_length must be a multiple of 4."
-            )
+            raise NotImplementedError("chunk_length must be a multiple of 4.")
         if left_context_length != 0 and left_context_length % 4 != 0:
             raise NotImplementedError(
                 "left_context_length must be 0 or a multiple of 4."
@@ -1289,8 +1346,9 @@ class Emformer(EncoderInterface):
         x_lens = ((x_lens - 1) // 2 - 1) // 2
         assert x.size(0) == x_lens.max().item()

-        output, output_lengths, output_states = \
-            self.encoder.infer(x, x_lens, states)  # (T, N, C)
+        output, output_lengths, output_states = self.encoder.infer(
+            x, x_lens, states
+        )  # (T, N, C)

         logits = self.encoder_output_layer(output)
         logits = logits.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
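The line "x_lens = ((x_lens - 1) // 2 - 1) // 2" above mirrors the two stride-2 stages of the subsampling frontend (subsampling_factor=4 per the check earlier in this diff), each of which maps a length L to (L - 1) // 2. A quick worked example with an assumed input length:

x_lens = 100
after_first = (x_lens - 1) // 2        # 49 frames after the first stride-2 stage
after_second = (after_first - 1) // 2  # 24 frames after the second stage
assert after_second == ((x_lens - 1) // 2 - 1) // 2 == 24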