mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00

update emformer.py

This commit is contained in:
parent 2d1b90f758
commit d58002c414
.pre-commit-config.yaml

@@ -4,6 +4,7 @@ repos:
     hooks:
       - id: black
         args: [--line-length=80]
+        additional_dependencies: ['click==8.0.1']
 
   - repo: https://github.com/PyCQA/flake8
     rev: 3.9.2
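
Note (editorial, not part of the diff): pinning click==8.0.1 in the black hook's additional_dependencies is the usual workaround for black releases of the 21.x era, which import a private click helper that click 8.1 removed; without the pin, the hook's pre-commit environment fails with an ImportError once a newer click is installed.
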
emformer.py

@@ -1,3 +1,22 @@
+# Copyright 2022 Xiaomi Corporation (Author: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# It is modified based on
+# https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py.
+
 import math
 import warnings
 from typing import List, Optional, Tuple
@@ -22,29 +41,32 @@ def _get_activation_module(activation: str) -> nn.Module:
 
 
 def _get_weight_init_gains(
-    weight_init_scale_strategy: Optional[str],
-    num_layers: int
+    weight_init_scale_strategy: Optional[str], num_layers: int
 ) -> List[Optional[float]]:
     if weight_init_scale_strategy is None:
         return [None for _ in range(num_layers)]
     elif weight_init_scale_strategy == "depthwise":
-        return [1.0 / math.sqrt(layer_idx + 1)
-                for layer_idx in range(num_layers)]
+        return [
+            1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers)
+        ]
     elif weight_init_scale_strategy == "constant":
         return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)]
     else:
-        raise ValueError(f"Unsupported weight_init_scale_strategy value"
-                         f"{weight_init_scale_strategy}")
+        raise ValueError(
+            f"Unsupported weight_init_scale_strategy value"
+            f"{weight_init_scale_strategy}"
+        )
 
 
 def _gen_attention_mask_block(
     col_widths: List[int],
     col_mask: List[bool],
     num_rows: int,
-    device: torch.device
+    device: torch.device,
 ) -> torch.Tensor:
-    assert len(col_widths) == len(col_mask), (
-        "Length of col_widths must match that of col_mask")
+    assert len(col_widths) == len(
+        col_mask
+    ), "Length of col_widths must match that of col_mask"
 
     mask_block = [
         torch.ones(num_rows, col_width, device=device)
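
The helper reformatted above maps a strategy name to one gain per layer; the "depthwise" strategy shrinks the gain with depth, and the commit keeps the message strings byte-for-byte, including the missing space between "value" and the interpolated strategy name. A minimal standalone sketch of the behaviour (editorial, not part of the commit; the plain function name is hypothetical):

# Editorial sketch of the gains computed by _get_weight_init_gains above.
import math
from typing import List, Optional


def weight_init_gains(
    strategy: Optional[str], num_layers: int
) -> List[Optional[float]]:
    if strategy is None:
        return [None for _ in range(num_layers)]
    if strategy == "depthwise":
        # Deeper layers get smaller initial gains: 1 / sqrt(layer_idx + 1).
        return [1.0 / math.sqrt(i + 1) for i in range(num_layers)]
    if strategy == "constant":
        return [1.0 / math.sqrt(2) for _ in range(num_layers)]
    raise ValueError(f"Unsupported weight_init_scale_strategy value {strategy}")


print(weight_init_gains("depthwise", 4))
# [1.0, 0.7071067811865475, 0.5773502691896258, 0.5]
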
@@ -99,9 +121,7 @@ class EmformerAttention(nn.Module):
 
         self.scaling = (self.embed_dim // self.nhead) ** -0.5
 
-        self.emb_to_key_value = nn.Linear(
-            embed_dim, 2 * embed_dim, bias=True
-        )
+        self.emb_to_key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=True)
         self.emb_to_query = nn.Linear(embed_dim, embed_dim, bias=True)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
 
@@ -119,7 +139,7 @@ class EmformerAttention(nn.Module):
         attention_mask: torch.Tensor,
         padding_mask: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        """ Given the entire attention weights, mask out unecessary connections
+        """Given the entire attention weights, mask out unecessary connections
         and optionally with padding positions, to obtain underlying chunk-wise
         attention probabilities.
 
@@ -154,7 +174,7 @@ class EmformerAttention(nn.Module):
             )
             attention_weights_float = attention_weights_float.masked_fill(
                 padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
-                self.negative_inf
+                self.negative_inf,
             )
             attention_weights_float = attention_weights_float.view(
                 B * self.nhead, Q, -1
@@ -164,9 +184,7 @@ class EmformerAttention(nn.Module):
             attention_weights_float, dim=-1
         ).type_as(attention_weights)
         attention_probs = nn.functional.dropout(
-            attention_probs,
-            p=float(self.dropout),
-            training=self.training
+            attention_probs, p=float(self.dropout), training=self.training
         )
         return attention_probs
 
@@ -181,7 +199,7 @@ class EmformerAttention(nn.Module):
         left_context_key: Optional[torch.Tensor] = None,
         left_context_val: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        """ Underlying chunk-wise attention implementation.
+        """Underlying chunk-wise attention implementation.
 
         L: length of left_context;
         S: length of summary;
@@ -242,14 +260,28 @@ class EmformerAttention(nn.Module):
             # [mems, right context, left context, uttrance]
             M = memory.size(0)
             R = right_context.size(0)
-            key = torch.cat([key[:M + R], left_context_key, key[M + R:]])
-            value = torch.cat([value[:M + R], left_context_val, value[M + R:]])
+            right_context_end_idx = M + R
+            key = torch.cat(
+                [
+                    key[:right_context_end_idx],
+                    left_context_key,
+                    key[right_context_end_idx:],
+                ]
+            )
+            value = torch.cat(
+                [
+                    value[:right_context_end_idx],
+                    left_context_val,
+                    value[right_context_end_idx:],
+                ]
+            )
 
         # Compute attention weights from query, key, and value.
         reshaped_query, reshaped_key, reshaped_value = [
-            tensor.contiguous().view(
-                -1, B * self.nhead, self.embed_dim // self.nhead
-            ).transpose(0, 1) for tensor in [query, key, value]
+            tensor.contiguous()
+            .view(-1, B * self.nhead, self.embed_dim // self.nhead)
+            .transpose(0, 1)
+            for tensor in [query, key, value]
         ]
         attention_weights = torch.bmm(
             reshaped_query * self.scaling, reshaped_key.transpose(1, 2)
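
The renamed right_context_end_idx marks the boundary after the memory and right-context rows of the key/value tensors. A shape-level sketch of the splice (editorial, not part of the commit; sizes are illustrative):

# Keys/values are ordered [memory, right context, left context, utterance]
# along dim 0; the cached left context is inserted after the first M + R rows.
import torch

M, R, L, U, B, D = 2, 4, 8, 16, 3, 32
key = torch.randn(M + R + U, B, D)        # projected from [memory, rc, utterance]
left_context_key = torch.randn(L, B, D)   # cached from preceding chunks

right_context_end_idx = M + R
key = torch.cat(
    [
        key[:right_context_end_idx],
        left_context_key,
        key[right_context_end_idx:],
    ]
)
assert key.shape == (M + R + L + U, B, D)
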
@@ -272,18 +304,21 @@ class EmformerAttention(nn.Module):
         attention = torch.bmm(attention_probs, reshaped_value)
         Q = query.size(0)
         assert attention.shape == (
-            B * self.nhead, Q, self.embed_dim // self.nhead,
+            B * self.nhead,
+            Q,
+            self.embed_dim // self.nhead,
         )
-        attention = attention.transpose(0, 1).contiguous().view(
-            Q, B, self.embed_dim
+        attention = (
+            attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim)
         )
 
         # Apply output projection.
         outputs = self.out_proj(attention)
 
         S = summary.size(0)
-        output_right_context_utterance = outputs[:Q - S]
-        output_memory = outputs[Q - S:]
+        summary_start_idx = Q - S
+        output_right_context_utterance = outputs[:summary_start_idx]
+        output_memory = outputs[summary_start_idx:]
         if self.tanh_on_mem:
             output_memory = torch.tanh(output_memory)
         else:
@@ -331,15 +366,14 @@ class EmformerAttention(nn.Module):
           - output of right context and utterance, with shape (R + U, B, D).
           - memory output, with shape (M, B, D), where M = S - 1 or M = 0.
         """
-        output_right_context_utterance, output_memory, _, _ = \
-            self._forward_impl(
-                utterance,
-                lengths,
-                right_context,
-                summary,
-                memory,
-                attention_mask
-            )
+        (
+            output_right_context_utterance,
+            output_memory,
+            _,
+            _,
+        ) = self._forward_impl(
+            utterance, lengths, right_context, summary, memory, attention_mask
+        )
         return output_right_context_utterance, output_memory[:-1]
 
     @torch.jit.export
@@ -394,29 +428,38 @@ class EmformerAttention(nn.Module):
         # query: [right context, utterance, summary]
         Q = right_context.size(0) + utterance.size(0) + summary.size(0)
         # key, value: [memory, right context, left context, uttrance]
-        KV = memory.size(0) + right_context.size(0) + \
-            left_context_key.size(0) + utterance.size(0)
-        attention_mask = torch.zeros(
-            Q, KV
-        ).to(dtype=torch.bool, device=utterance.device)
+        KV = (
+            memory.size(0)
+            + right_context.size(0)
+            + left_context_key.size(0)
+            + utterance.size(0)
+        )
+        attention_mask = torch.zeros(Q, KV).to(
+            dtype=torch.bool, device=utterance.device
+        )
         # Disallow attention bettween the summary vector with the memory bank
-        attention_mask[-1, :memory.size(0)] = True
-        output_right_context_utterance, output_memory, key, value = \
-            self._forward_impl(
-                utterance,
-                lengths,
-                right_context,
-                summary,
-                memory,
-                attention_mask,
-                left_context_key=left_context_key,
-                left_context_val=left_context_val,
-            )
+        attention_mask[-1, : memory.size(0)] = True
+        (
+            output_right_context_utterance,
+            output_memory,
+            key,
+            value,
+        ) = self._forward_impl(
+            utterance,
+            lengths,
+            right_context,
+            summary,
+            memory,
+            attention_mask,
+            left_context_key=left_context_key,
+            left_context_val=left_context_val,
+        )
+        right_context_end_idx = memory.size(0) + right_context.size(0)
         return (
             output_right_context_utterance,
             output_memory,
-            key[memory.size(0) + right_context.size(0):],
-            value[memory.size(0) + right_context.size(0):],
+            key[right_context_end_idx:],
+            value[right_context_end_idx:],
         )
 
 
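
A shape-level sketch of the inference-time mask built in the hunk above (editorial, not part of the commit; sizes illustrative, and it assumes the memory mechanism is enabled so the summary contributes exactly one query):

# Queries are [right context, utterance, summary]; keys/values are
# [memory, right context, left context, utterance]. Only the single summary
# query row is barred from attending to the memory columns.
import torch

M, R, L, U = 2, 4, 8, 16
Q = R + U + 1
KV = M + R + L + U
attention_mask = torch.zeros(Q, KV, dtype=torch.bool)
attention_mask[-1, :M] = True   # summary must not attend to the memory bank
assert attention_mask.sum().item() == M
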
@@ -499,9 +542,7 @@ class EmformerLayer(nn.Module):
         self.use_memory = max_memory_size > 0
 
     def _init_state(
-        self,
-        batch_size: int,
-        device: Optional[torch.device]
+        self, batch_size: int, device: Optional[torch.device]
     ) -> List[torch.Tensor]:
         """Initialize states with zeros."""
         empty_memory = torch.zeros(
@@ -519,8 +560,7 @@ class EmformerLayer(nn.Module):
         return [empty_memory, left_context_key, left_context_val, past_length]
 
     def _unpack_state(
-        self,
-        state: List[torch.Tensor]
+        self, state: List[torch.Tensor]
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Unpack cached states including:
         1) output memory from previous chunks in the lower layer;
@@ -532,11 +572,13 @@ class EmformerLayer(nn.Module):
         past_memory_length = min(
             self.max_memory_size, math.ceil(past_length / self.chunk_length)
         )
-        pre_memory = state[0][self.max_memory_size - past_memory_length:]
-        left_context_key = \
-            state[1][self.left_context_length - past_left_context_length:]
-        left_context_val = \
-            state[2][self.left_context_length - past_left_context_length:]
+        memory_start_idx = self.max_memory_size - past_memory_length
+        pre_memory = state[0][memory_start_idx:]
+        left_context_start_idx = (
+            self.left_context_length - past_left_context_length
+        )
+        left_context_key = state[1][left_context_start_idx:]
+        left_context_val = state[2][left_context_start_idx:]
         return pre_memory, left_context_key, left_context_val
 
     def _pack_state(
@@ -556,40 +598,46 @@ class EmformerLayer(nn.Module):
         new_memory = torch.cat([state[0], memory])
         new_key = torch.cat([state[1], next_key])
         new_val = torch.cat([state[2], next_val])
-        state[0] = new_memory[new_memory.size(0) - self.max_memory_size:]
-        state[1] = new_key[new_key.size(0) - self.left_context_length:]
-        state[2] = new_val[new_val.size(0) - self.left_context_length:]
+        memory_start_idx = new_memory.size(0) - self.max_memory_size
+        state[0] = new_memory[memory_start_idx:]
+        key_start_idx = new_key.size(0) - self.left_context_length
+        state[1] = new_key[key_start_idx:]
+        val_start_idx = new_val.size(0) - self.left_context_length
+        state[2] = new_val[val_start_idx:]
         state[3] = state[3] + update_length
         return state
 
     def _apply_pre_attention_layer_norm(
         self, utterance: torch.Tensor, right_context: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Apply layer normalization before attention. """
+        """Apply layer normalization before attention."""
         layer_norm_input = self.layer_norm_input(
             torch.cat([right_context, utterance])
         )
-        layer_norm_utterance = layer_norm_input[right_context.size(0):]
-        layer_norm_right_context = layer_norm_input[:right_context.size(0)]
+        right_context_end_idx = right_context.size(0)
+        layer_norm_utterance = layer_norm_input[right_context_end_idx:]
+        layer_norm_right_context = layer_norm_input[:right_context_end_idx]
         return layer_norm_utterance, layer_norm_right_context
 
     def _apply_post_attention_ffn_layer_norm(
         self,
         output_right_context_utterance: torch.Tensor,
         utterance: torch.Tensor,
-        right_context: torch.Tensor
+        right_context: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Apply feed forward and layer normalization after attention."""
         # Apply residual connection between input and attention output.
-        result = self.dropout(output_right_context_utterance) + \
-            torch.cat([right_context, utterance])
+        result = self.dropout(output_right_context_utterance) + torch.cat(
+            [right_context, utterance]
+        )
         # Apply feedforward module and residual connection.
         result = self.pos_ff(result) + result
         # Apply layer normalization for output.
         result = self.layer_norm_output(result)
 
-        output_utterance = result[right_context.size(0):]
-        output_right_context = result[:right_context.size(0)]
+        right_context_end_idx = right_context.size(0)
+        output_utterance = result[right_context_end_idx:]
+        output_right_context = result[:right_context_end_idx]
         return output_utterance, output_right_context
 
     def _apply_attention_forward(
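
The renamed *_start_idx variables above all slice the per-layer streaming cache. A small sketch of that cache and a _pack_state-style update (editorial, not part of the commit; sizes, and the exact dtype/shape of the frame counter, are illustrative):

# state = [memory, left_context_key, left_context_val, past_length]; the first
# three are kept at fixed maximum sizes, so only their newest rows are valid.
import torch

max_memory_size, left_context_length, B, D = 4, 8, 3, 32
state = [
    torch.zeros(max_memory_size, B, D),        # memory bank
    torch.zeros(left_context_length, B, D),    # cached attention keys
    torch.zeros(left_context_length, B, D),    # cached attention values
    torch.zeros(1, dtype=torch.int64),         # frames processed so far
]

# Append the current chunk's keys, then keep only the newest rows.
new_key = torch.cat([state[1], torch.randn(16, B, D)])
key_start_idx = new_key.size(0) - left_context_length
state[1] = new_key[key_start_idx:]
assert state[1].shape == (left_context_length, B, D)
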
@@ -600,16 +648,16 @@ class EmformerLayer(nn.Module):
         memory: torch.Tensor,
         attention_mask: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Apply attention in non-infer mode. """
+        """Apply attention in non-infer mode."""
         if attention_mask is None:
             raise ValueError(
                 "attention_mask must be not None in non-infer mode. "
             )
 
         if self.use_memory:
-            summary = self.summary_op(
-                utterance.permute(1, 2, 0)
-            ).permute(2, 0, 1)
+            summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
+                2, 0, 1
+            )
         else:
             summary = torch.empty(0).to(
                 dtype=utterance.dtype, device=utterance.device
@@ -646,27 +694,32 @@ class EmformerLayer(nn.Module):
         """
         if state is None:
             state = self._init_state(utterance.size(1), device=utterance.device)
-        pre_memory, left_context_key, left_context_val = \
-            self._unpack_state(state)
+        pre_memory, left_context_key, left_context_val = self._unpack_state(
+            state
+        )
         if self.use_memory:
-            summary = self.summary_op(
-                utterance.permute(1, 2, 0)
-            ).permute(2, 0, 1)
+            summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
+                2, 0, 1
+            )
             summary = summary[:1]
         else:
             summary = torch.empty(0).to(
                 dtype=utterance.dtype, device=utterance.device
             )
-        output_right_context_utterance, output_memory, next_key, next_val = \
-            self.attention.infer(
-                utterance=utterance,
-                lengths=lengths,
-                right_context=right_context,
-                summary=summary,
-                memory=pre_memory,
-                left_context_key=left_context_key,
-                left_context_val=left_context_val,
-            )
+        (
+            output_right_context_utterance,
+            output_memory,
+            next_key,
+            next_val,
+        ) = self.attention.infer(
+            utterance=utterance,
+            lengths=lengths,
+            right_context=right_context,
+            summary=summary,
+            memory=pre_memory,
+            left_context_key=left_context_key,
+            left_context_val=left_context_val,
+        )
         state = self._pack_state(
             next_key, next_val, utterance.size(0), memory, state
         )
@@ -718,20 +771,22 @@ class EmformerLayer(nn.Module):
             layer_norm_utterance,
             layer_norm_right_context,
         ) = self._apply_pre_attention_layer_norm(utterance, right_context)
-        output_right_context_utterance, output_memory = \
-            self._apply_attention_forward(
-                layer_norm_utterance,
-                lengths,
-                layer_norm_right_context,
-                memory,
-                attention_mask,
-            )
-        output_utterance, output_right_context = \
-            self._apply_post_attention_ffn_layer_norm(
-                output_right_context_utterance,
-                utterance,
-                right_context
-            )
+        (
+            output_right_context_utterance,
+            output_memory,
+        ) = self._apply_attention_forward(
+            layer_norm_utterance,
+            lengths,
+            layer_norm_right_context,
+            memory,
+            attention_mask,
+        )
+        (
+            output_utterance,
+            output_right_context,
+        ) = self._apply_post_attention_ffn_layer_norm(
+            output_right_context_utterance, utterance, right_context
+        )
         return output_utterance, output_right_context, output_memory
 
     @torch.jit.export
@@ -745,63 +800,66 @@ class EmformerLayer(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
         """Forward pass for inference.
 
         1) Apply layer normalization on input utterance and right context
         before attention;
         2) Apply attention module with cached state, compute updated utterance,
         right context, and memory, and update state;
-        3) Apply feed forward module and layer normalization on output utterance
-        and right context.
+        3) Apply feed forward module and layer normalization on output
+        utterance and right context.
 
         B: batch size;
         D: embedding dimension;
         R: length of right_context;
         U: length of utterance;
         M: length of memory.
 
         Args:
           utterance (torch.Tensor):
            Utterance frames, with shape (U, B, D).
          lengths (torch.Tensor):
            With shape (B,) and i-th element representing
            number of valid frames for i-th batch element in utterance.
          right_context (torch.Tensor):
            Right context frames, with shape (R, B, D).
          memory (torch.Tensor):
            Memory elements, with shape (M, B, D).
          state (List[torch.Tensor], optional):
            List of tensors representing layer internal state generated in
            preceding computation. (default=None)
 
        Returns:
          (Tensor, Tensor, List[torch.Tensor], Tensor):
            - output utterance, with shape (U, B, D);
            - output right_context, with shape (R, B, D);
            - output memory, with shape (1, B, D) or (0, B, D).
            - output state.
        """
        (
            layer_norm_utterance,
            layer_norm_right_context,
        ) = self._apply_pre_attention_layer_norm(utterance, right_context)
-        output_right_context_utterance, output_memory, output_state = \
-            self._apply_attention_infer(
-                layer_norm_utterance,
-                lengths,
-                layer_norm_right_context,
-                memory,
-                state
-            )
-        output_utterance, output_right_context = \
-            self._apply_post_attention_ffn_layer_norm(
-                output_right_context_utterance,
-                utterance,
-                right_context
-            )
+        (
+            output_right_context_utterance,
+            output_memory,
+            output_state,
+        ) = self._apply_attention_infer(
+            layer_norm_utterance,
+            lengths,
+            layer_norm_right_context,
+            memory,
+            state,
+        )
+        (
+            output_utterance,
+            output_right_context,
+        ) = self._apply_post_attention_ffn_layer_norm(
+            output_right_context_utterance, utterance, right_context
+        )
        return (
            output_utterance,
            output_right_context,
            output_memory,
-            output_state
+            output_state,
        )
 
 
@@ -895,7 +953,7 @@ class EmformerEncoder(nn.Module):
         self.max_memory_size = max_memory_size
 
     def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor:
-        """Hard copy each chunk's right context and concat them. """
+        """Hard copy each chunk's right context and concat them."""
         T = x.shape[0]
         num_segs = math.ceil(
             (T - self.right_context_length) / self.chunk_length
@@ -905,7 +963,8 @@ class EmformerEncoder(nn.Module):
             start = (seg_idx + 1) * self.chunk_length
             end = start + self.right_context_length
             right_context_blocks.append(x[start:end])
-        right_context_blocks.append(x[T - self.right_context_length:])
+        last_right_context_start_idx = T - self.right_context_length
+        right_context_blocks.append(x[last_right_context_start_idx:])
         return torch.cat(right_context_blocks)
 
     def _gen_attention_mask_col_widths(
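
A self-contained sketch of the right-context copying shown above (editorial, not part of the commit; sizes are illustrative, and the loop bound of num_segs - 1 is assumed from the torchaudio implementation this file is modified from):

# Each chunk hard-copies the right_context_length frames that follow it; the
# final block is taken from the very end of the input.
import math
import torch

chunk_length, right_context_length = 4, 2
x = torch.arange(14).float().unsqueeze(-1).unsqueeze(-1)   # (T, B=1, D=1)
T = x.shape[0]
num_segs = math.ceil((T - right_context_length) / chunk_length)

right_context_blocks = []
for seg_idx in range(num_segs - 1):
    start = (seg_idx + 1) * chunk_length
    right_context_blocks.append(x[start:start + right_context_length])
last_right_context_start_idx = T - right_context_length
right_context_blocks.append(x[last_right_context_start_idx:])
right_context = torch.cat(right_context_blocks)
assert right_context.shape[0] == num_segs * right_context_length
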
@@ -981,8 +1040,9 @@ class EmformerEncoder(nn.Module):
             num_cols = 9
             # right context and utterance both attend to memory, right context,
             # utterance
-            right_context_utterance_cols_mask = \
-                [idx in [1, 4, 7] for idx in range(num_cols)]
+            right_context_utterance_cols_mask = [
+                idx in [1, 4, 7] for idx in range(num_cols)
+            ]
             # summary attends to right context, utterance
             summary_cols_mask = [idx in [4, 7] for idx in range(num_cols)]
             masks_to_concat = [right_context_mask, utterance_mask, summary_mask]
@@ -990,8 +1050,9 @@ class EmformerEncoder(nn.Module):
             num_cols = 6
             # right context and utterance both attend to right context and
             # utterance
-            right_context_utterance_cols_mask = \
-                [idx in [1, 4] for idx in range(num_cols)]
+            right_context_utterance_cols_mask = [
+                idx in [1, 4] for idx in range(num_cols)
+            ]
             summary_cols_mask = None
             masks_to_concat = [right_context_mask, utterance_mask]
 
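
A sketch of how these boolean column masks are consumed by _gen_attention_mask_block, whose opening lines appear in an earlier hunk (editorial, not part of the commit; the standalone function name and the column widths are illustrative):

# Every True column becomes a block of ones of the given width, every False
# column a block of zeros, concatenated along the column dimension.
import torch


def gen_attention_mask_block(col_widths, col_mask, num_rows, device=None):
    assert len(col_widths) == len(col_mask)
    return torch.cat(
        [
            torch.ones(num_rows, width, device=device)
            if is_ones_col
            else torch.zeros(num_rows, width, device=device)
            for width, is_ones_col in zip(col_widths, col_mask)
        ],
        dim=1,
    )


num_cols = 6
right_context_utterance_cols_mask = [idx in [1, 4] for idx in range(num_cols)]
block = gen_attention_mask_block(
    [3, 2, 5, 1, 4, 2], right_context_utterance_cols_mask, num_rows=2
)
assert block.shape == (2, 17)
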
@@ -1002,7 +1063,7 @@ class EmformerEncoder(nn.Module):
                 col_widths,
                 right_context_utterance_cols_mask,
                 self.right_context_length,
-                utterance.device
+                utterance.device,
             )
             right_context_mask.append(right_context_mask_block)
 
@@ -1053,13 +1114,13 @@ class EmformerEncoder(nn.Module):
            right_context at the end.
         """
         right_context = self._gen_right_context(x)
-        utterance = x[:x.size(0) - self.right_context_length]
+        utterance = x[: x.size(0) - self.right_context_length]
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
         attention_mask = self._gen_attention_mask(utterance)
         memory = (
-            self.init_memory_op(
-                utterance.permute(1, 2, 0)
-            ).permute(2, 0, 1)[:-1]
+            self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
+                :-1
+            ]
             if self.use_memory
             else torch.empty(0).to(dtype=x.dtype, device=x.device)
         )
@@ -1159,13 +1220,9 @@ class Emformer(EncoderInterface):
         self.subsampling_factor = subsampling_factor
         self.right_context_length = right_context_length
         if subsampling_factor != 4:
-            raise NotImplementedError(
-                "Support only 'subsampling_factor=4'."
-            )
+            raise NotImplementedError("Support only 'subsampling_factor=4'.")
         if chunk_length % 4 != 0:
-            raise NotImplementedError(
-                "chunk_length must be a mutiple of 4."
-            )
+            raise NotImplementedError("chunk_length must be a mutiple of 4.")
         if left_context_length != 0 and left_context_length % 4 != 0:
             raise NotImplementedError(
                 "left_context_length must be 0 or a mutiple of 4."
@@ -1289,8 +1346,9 @@ class Emformer(EncoderInterface):
         x_lens = ((x_lens - 1) // 2 - 1) // 2
         assert x.size(0) == x_lens.max().item()
 
-        output, output_lengths, output_states = \
-            self.encoder.infer(x, x_lens, states)  # (T, N, C)
+        output, output_lengths, output_states = self.encoder.infer(
+            x, x_lens, states
+        )  # (T, N, C)
 
         logits = self.encoder_output_layer(output)
         logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
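
A sketch of the length arithmetic kept as context above (editorial, not part of the commit; it assumes the subsampling frontend uses two stride-2 convolutions, as elsewhere in icefall, so T input frames become roughly T/4 output frames):

# ((T - 1) // 2 - 1) // 2 tracks the exact output length of two stride-2 stages.
import torch

x_lens = torch.tensor([100, 57, 16])
subsampled = ((x_lens - 1) // 2 - 1) // 2
print(subsampled)  # tensor([24, 13,  3])
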