update emformer.py

yaozengwei 2022-04-08 20:31:32 +08:00
parent 2d1b90f758
commit d58002c414
2 changed files with 233 additions and 174 deletions

View File

@@ -4,6 +4,7 @@ repos:
     hooks:
       - id: black
         args: [--line-length=80]
+        additional_dependencies: ['click==8.0.1']
   - repo: https://github.com/PyCQA/flake8
     rev: 3.9.2

View File

@@ -1,3 +1,22 @@
+# Copyright 2022 Xiaomi Corporation (Author: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# It is modified based on
+# https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py.
+
 import math
 import warnings
 from typing import List, Optional, Tuple
@@ -22,29 +41,32 @@ def _get_activation_module(activation: str) -> nn.Module:
 def _get_weight_init_gains(
-    weight_init_scale_strategy: Optional[str],
-    num_layers: int
+    weight_init_scale_strategy: Optional[str], num_layers: int
 ) -> List[Optional[float]]:
     if weight_init_scale_strategy is None:
         return [None for _ in range(num_layers)]
     elif weight_init_scale_strategy == "depthwise":
-        return [1.0 / math.sqrt(layer_idx + 1)
-                for layer_idx in range(num_layers)]
+        return [
+            1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(num_layers)
+        ]
     elif weight_init_scale_strategy == "constant":
         return [1.0 / math.sqrt(2) for layer_idx in range(num_layers)]
     else:
-        raise ValueError(f"Unsupported weight_init_scale_strategy value"
-                         f"{weight_init_scale_strategy}")
+        raise ValueError(
+            f"Unsupported weight_init_scale_strategy value"
+            f"{weight_init_scale_strategy}"
+        )


 def _gen_attention_mask_block(
     col_widths: List[int],
     col_mask: List[bool],
     num_rows: int,
-    device: torch.device
+    device: torch.device,
 ) -> torch.Tensor:
-    assert len(col_widths) == len(col_mask), (
-        "Length of col_widths must match that of col_mask")
+    assert len(col_widths) == len(
+        col_mask
+    ), "Length of col_widths must match that of col_mask"
     mask_block = [
         torch.ones(num_rows, col_width, device=device)
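
For readers skimming the diff, here is a standalone sketch of what these two helpers compute; the names mirror the diff, but the toy sizes and prints are illustrative only (ones mark column groups that a row of queries may attend to, matching the comments further down in the file):

import math

import torch


# Depthwise init gains: layer l gets 1 / sqrt(l + 1), so deeper layers start smaller.
gains = [1.0 / math.sqrt(layer_idx + 1) for layer_idx in range(4)]
print(gains)  # [1.0, 0.707..., 0.577..., 0.5]


def gen_attention_mask_block(col_widths, col_mask, num_rows, device):
    # One (num_rows, col_width) block per column group: ones if the group may be
    # attended to, zeros otherwise; blocks are concatenated along the column dim.
    assert len(col_widths) == len(col_mask)
    return torch.cat(
        [
            torch.ones(num_rows, w, device=device)
            if m
            else torch.zeros(num_rows, w, device=device)
            for w, m in zip(col_widths, col_mask)
        ],
        dim=1,
    )


print(gen_attention_mask_block([2, 3], [True, False], num_rows=2, device="cpu"))
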
@@ -99,9 +121,7 @@ class EmformerAttention(nn.Module):
         self.scaling = (self.embed_dim // self.nhead) ** -0.5

-        self.emb_to_key_value = nn.Linear(
-            embed_dim, 2 * embed_dim, bias=True
-        )
+        self.emb_to_key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=True)
         self.emb_to_query = nn.Linear(embed_dim, embed_dim, bias=True)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
@@ -119,7 +139,7 @@ class EmformerAttention(nn.Module):
         attention_mask: torch.Tensor,
         padding_mask: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        """ Given the entire attention weights, mask out unecessary connections
+        """Given the entire attention weights, mask out unecessary connections
         and optionally with padding positions, to obtain underlying chunk-wise
         attention probabilities.
@ -154,7 +174,7 @@ class EmformerAttention(nn.Module):
) )
attention_weights_float = attention_weights_float.masked_fill( attention_weights_float = attention_weights_float.masked_fill(
padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
self.negative_inf self.negative_inf,
) )
attention_weights_float = attention_weights_float.view( attention_weights_float = attention_weights_float.view(
B * self.nhead, Q, -1 B * self.nhead, Q, -1
@@ -164,9 +184,7 @@ class EmformerAttention(nn.Module):
             attention_weights_float, dim=-1
         ).type_as(attention_weights)
         attention_probs = nn.functional.dropout(
-            attention_probs,
-            p=float(self.dropout),
-            training=self.training
+            attention_probs, p=float(self.dropout), training=self.training
         )
         return attention_probs
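
The mask-then-softmax pattern above is easier to see on toy tensors; a minimal sketch with made-up shapes and dropout probability:

import torch
import torch.nn as nn

attention_weights = torch.randn(2, 3, 5)  # (B * nhead, Q, KV) with B * nhead = 2
attention_mask = torch.zeros(3, 5, dtype=torch.bool)
attention_mask[:, 4] = True  # forbid every query from attending to the last key

negative_inf = float("-inf")
attention_weights_float = attention_weights.float().masked_fill(
    attention_mask.unsqueeze(0), negative_inf
)
attention_probs = nn.functional.softmax(attention_weights_float, dim=-1).type_as(
    attention_weights
)
attention_probs = nn.functional.dropout(attention_probs, p=0.1, training=True)
print(attention_probs.shape)  # torch.Size([2, 3, 5]); the masked column stays 0
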
@@ -181,7 +199,7 @@ class EmformerAttention(nn.Module):
         left_context_key: Optional[torch.Tensor] = None,
         left_context_val: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        """ Underlying chunk-wise attention implementation.
+        """Underlying chunk-wise attention implementation.

         L: length of left_context;
         S: length of summary;
@@ -242,14 +260,28 @@ class EmformerAttention(nn.Module):
             # [mems, right context, left context, uttrance]
             M = memory.size(0)
             R = right_context.size(0)
-            key = torch.cat([key[:M + R], left_context_key, key[M + R:]])
-            value = torch.cat([value[:M + R], left_context_val, value[M + R:]])
+            right_context_end_idx = M + R
+            key = torch.cat(
+                [
+                    key[:right_context_end_idx],
+                    left_context_key,
+                    key[right_context_end_idx:],
+                ]
+            )
+            value = torch.cat(
+                [
+                    value[:right_context_end_idx],
+                    left_context_val,
+                    value[right_context_end_idx:],
+                ]
+            )

         # Compute attention weights from query, key, and value.
         reshaped_query, reshaped_key, reshaped_value = [
-            tensor.contiguous().view(
-                -1, B * self.nhead, self.embed_dim // self.nhead
-            ).transpose(0, 1) for tensor in [query, key, value]
+            tensor.contiguous()
+            .view(-1, B * self.nhead, self.embed_dim // self.nhead)
+            .transpose(0, 1)
+            for tensor in [query, key, value]
         ]
         attention_weights = torch.bmm(
             reshaped_query * self.scaling, reshaped_key.transpose(1, 2)
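
A small sketch of the key/value splice performed above; M, R, L, U below are toy sizes, not values from the commit:

import torch

M, R, L, U, B, D = 2, 3, 4, 5, 1, 8
key = torch.randn(M + R + U, B, D)       # layout: [memory, right context, utterance]
left_context_key = torch.randn(L, B, D)  # cached keys from previous chunks

# Splice the cached left context in right before the utterance block.
right_context_end_idx = M + R
key = torch.cat(
    [
        key[:right_context_end_idx],
        left_context_key,
        key[right_context_end_idx:],
    ]
)
print(key.shape)  # torch.Size([14, 1, 8]) == (M + R + L + U, B, D)
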
@@ -272,18 +304,21 @@ class EmformerAttention(nn.Module):
         attention = torch.bmm(attention_probs, reshaped_value)
         Q = query.size(0)
         assert attention.shape == (
-            B * self.nhead, Q, self.embed_dim // self.nhead,
+            B * self.nhead,
+            Q,
+            self.embed_dim // self.nhead,
         )
-        attention = attention.transpose(0, 1).contiguous().view(
-            Q, B, self.embed_dim
-        )
+        attention = (
+            attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim)
+        )

         # Apply output projection.
         outputs = self.out_proj(attention)

         S = summary.size(0)
-        output_right_context_utterance = outputs[:Q - S]
-        output_memory = outputs[Q - S:]
+        summary_start_idx = Q - S
+        output_right_context_utterance = outputs[:summary_start_idx]
+        output_memory = outputs[summary_start_idx:]
         if self.tanh_on_mem:
             output_memory = torch.tanh(output_memory)
         else:
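
The reshaping used in this block is the usual multi-head layout change from (T, B, D) to (B * nhead, T, D // nhead); a quick sketch with made-up sizes:

import torch

T, B, D, nhead = 6, 2, 16, 4
x = torch.randn(T, B, D)

reshaped = (
    x.contiguous()
    .view(-1, B * nhead, D // nhead)
    .transpose(0, 1)
)
print(reshaped.shape)  # torch.Size([8, 6, 4]) == (B * nhead, T, D // nhead)
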
@@ -331,15 +366,14 @@ class EmformerAttention(nn.Module):
         - output of right context and utterance, with shape (R + U, B, D).
         - memory output, with shape (M, B, D), where M = S - 1 or M = 0.
         """
-        output_right_context_utterance, output_memory, _, _ = \
-            self._forward_impl(
-                utterance,
-                lengths,
-                right_context,
-                summary,
-                memory,
-                attention_mask
-            )
+        (
+            output_right_context_utterance,
+            output_memory,
+            _,
+            _,
+        ) = self._forward_impl(
+            utterance, lengths, right_context, summary, memory, attention_mask
+        )
         return output_right_context_utterance, output_memory[:-1]

     @torch.jit.export
@@ -394,29 +428,38 @@ class EmformerAttention(nn.Module):
         # query: [right context, utterance, summary]
         Q = right_context.size(0) + utterance.size(0) + summary.size(0)
         # key, value: [memory, right context, left context, uttrance]
-        KV = memory.size(0) + right_context.size(0) + \
-            left_context_key.size(0) + utterance.size(0)
-        attention_mask = torch.zeros(
-            Q, KV
-        ).to(dtype=torch.bool, device=utterance.device)
+        KV = (
+            memory.size(0)
+            + right_context.size(0)
+            + left_context_key.size(0)
+            + utterance.size(0)
+        )
+        attention_mask = torch.zeros(Q, KV).to(
+            dtype=torch.bool, device=utterance.device
+        )
         # Disallow attention bettween the summary vector with the memory bank
-        attention_mask[-1, :memory.size(0)] = True
-        output_right_context_utterance, output_memory, key, value = \
-            self._forward_impl(
-                utterance,
-                lengths,
-                right_context,
-                summary,
-                memory,
-                attention_mask,
-                left_context_key=left_context_key,
-                left_context_val=left_context_val,
-            )
+        attention_mask[-1, : memory.size(0)] = True
+        (
+            output_right_context_utterance,
+            output_memory,
+            key,
+            value,
+        ) = self._forward_impl(
+            utterance,
+            lengths,
+            right_context,
+            summary,
+            memory,
+            attention_mask,
+            left_context_key=left_context_key,
+            left_context_val=left_context_val,
+        )
+        right_context_end_idx = memory.size(0) + right_context.size(0)
         return (
             output_right_context_utterance,
             output_memory,
-            key[memory.size(0) + right_context.size(0):],
-            value[memory.size(0) + right_context.size(0):],
+            key[right_context_end_idx:],
+            value[right_context_end_idx:],
         )
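
A toy version of the Q x KV mask built in infer() above; the sizes are invented, and True marks positions that are masked out:

import torch

M, R, L, U, S = 2, 3, 4, 5, 1  # memory, right context, left context, utterance, summary
Q = R + U + S
KV = M + R + L + U

attention_mask = torch.zeros(Q, KV).to(dtype=torch.bool)
attention_mask[-1, :M] = True  # only the summary query is blocked from the memory bank
print(attention_mask.sum().item())  # 2 masked entries, everything else may attend
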
@@ -499,9 +542,7 @@ class EmformerLayer(nn.Module):
         self.use_memory = max_memory_size > 0

     def _init_state(
-        self,
-        batch_size: int,
-        device: Optional[torch.device]
+        self, batch_size: int, device: Optional[torch.device]
     ) -> List[torch.Tensor]:
         """Initialize states with zeros."""
         empty_memory = torch.zeros(
@@ -519,8 +560,7 @@ class EmformerLayer(nn.Module):
         return [empty_memory, left_context_key, left_context_val, past_length]

     def _unpack_state(
-        self,
-        state: List[torch.Tensor]
+        self, state: List[torch.Tensor]
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Unpack cached states including:
         1) output memory from previous chunks in the lower layer;
@@ -532,11 +572,13 @@ class EmformerLayer(nn.Module):
         past_memory_length = min(
             self.max_memory_size, math.ceil(past_length / self.chunk_length)
         )
-        pre_memory = state[0][self.max_memory_size - past_memory_length:]
-        left_context_key = \
-            state[1][self.left_context_length - past_left_context_length:]
-        left_context_val = \
-            state[2][self.left_context_length - past_left_context_length:]
+        memory_start_idx = self.max_memory_size - past_memory_length
+        pre_memory = state[0][memory_start_idx:]
+        left_context_start_idx = (
+            self.left_context_length - past_left_context_length
+        )
+        left_context_key = state[1][left_context_start_idx:]
+        left_context_val = state[2][left_context_start_idx:]
         return pre_memory, left_context_key, left_context_val

     def _pack_state(
@@ -556,40 +598,46 @@ class EmformerLayer(nn.Module):
         new_memory = torch.cat([state[0], memory])
         new_key = torch.cat([state[1], next_key])
         new_val = torch.cat([state[2], next_val])
-        state[0] = new_memory[new_memory.size(0) - self.max_memory_size:]
-        state[1] = new_key[new_key.size(0) - self.left_context_length:]
-        state[2] = new_val[new_val.size(0) - self.left_context_length:]
+        memory_start_idx = new_memory.size(0) - self.max_memory_size
+        state[0] = new_memory[memory_start_idx:]
+        key_start_idx = new_key.size(0) - self.left_context_length
+        state[1] = new_key[key_start_idx:]
+        val_start_idx = new_val.size(0) - self.left_context_length
+        state[2] = new_val[val_start_idx:]
         state[3] = state[3] + update_length
         return state

     def _apply_pre_attention_layer_norm(
         self, utterance: torch.Tensor, right_context: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Apply layer normalization before attention. """
+        """Apply layer normalization before attention."""
         layer_norm_input = self.layer_norm_input(
             torch.cat([right_context, utterance])
         )
-        layer_norm_utterance = layer_norm_input[right_context.size(0):]
-        layer_norm_right_context = layer_norm_input[:right_context.size(0)]
+        right_context_end_idx = right_context.size(0)
+        layer_norm_utterance = layer_norm_input[right_context_end_idx:]
+        layer_norm_right_context = layer_norm_input[:right_context_end_idx]
         return layer_norm_utterance, layer_norm_right_context

     def _apply_post_attention_ffn_layer_norm(
         self,
         output_right_context_utterance: torch.Tensor,
         utterance: torch.Tensor,
-        right_context: torch.Tensor
+        right_context: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Apply feed forward and layer normalization after attention."""
         # Apply residual connection between input and attention output.
-        result = self.dropout(output_right_context_utterance) + \
-            torch.cat([right_context, utterance])
+        result = self.dropout(output_right_context_utterance) + torch.cat(
+            [right_context, utterance]
+        )
         # Apply feedforward module and residual connection.
         result = self.pos_ff(result) + result
         # Apply layer normalization for output.
         result = self.layer_norm_output(result)
-        output_utterance = result[right_context.size(0):]
-        output_right_context = result[:right_context.size(0)]
+        right_context_end_idx = right_context.size(0)
+        output_utterance = result[right_context_end_idx:]
+        output_right_context = result[:right_context_end_idx]
         return output_utterance, output_right_context

     def _apply_attention_forward(
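
The cache update in _pack_state amounts to a sliding window over the most recent entries; a minimal sketch for the memory part, with toy sizes and a single new chunk:

import torch

max_memory_size, B, D = 4, 1, 8
cached_memory = torch.zeros(max_memory_size, B, D)  # state[0]
new_chunk_memory = torch.randn(1, B, D)             # memory emitted by this chunk

new_memory = torch.cat([cached_memory, new_chunk_memory])
memory_start_idx = new_memory.size(0) - max_memory_size
cached_memory = new_memory[memory_start_idx:]  # keep only the newest entries
print(cached_memory.shape)  # torch.Size([4, 1, 8]); the oldest entry is dropped
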
@@ -600,16 +648,16 @@ class EmformerLayer(nn.Module):
         memory: torch.Tensor,
         attention_mask: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Apply attention in non-infer mode. """
+        """Apply attention in non-infer mode."""
         if attention_mask is None:
             raise ValueError(
                 "attention_mask must be not None in non-infer mode. "
             )
         if self.use_memory:
-            summary = self.summary_op(
-                utterance.permute(1, 2, 0)
-            ).permute(2, 0, 1)
+            summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
+                2, 0, 1
+            )
         else:
             summary = torch.empty(0).to(
                 dtype=utterance.dtype, device=utterance.device
@@ -646,27 +694,32 @@ class EmformerLayer(nn.Module):
         """
         if state is None:
             state = self._init_state(utterance.size(1), device=utterance.device)
-        pre_memory, left_context_key, left_context_val = \
-            self._unpack_state(state)
+        pre_memory, left_context_key, left_context_val = self._unpack_state(
+            state
+        )
         if self.use_memory:
-            summary = self.summary_op(
-                utterance.permute(1, 2, 0)
-            ).permute(2, 0, 1)
+            summary = self.summary_op(utterance.permute(1, 2, 0)).permute(
+                2, 0, 1
+            )
             summary = summary[:1]
         else:
             summary = torch.empty(0).to(
                 dtype=utterance.dtype, device=utterance.device
             )
-        output_right_context_utterance, output_memory, next_key, next_val = \
-            self.attention.infer(
-                utterance=utterance,
-                lengths=lengths,
-                right_context=right_context,
-                summary=summary,
-                memory=pre_memory,
-                left_context_key=left_context_key,
-                left_context_val=left_context_val,
-            )
+        (
+            output_right_context_utterance,
+            output_memory,
+            next_key,
+            next_val,
+        ) = self.attention.infer(
+            utterance=utterance,
+            lengths=lengths,
+            right_context=right_context,
+            summary=summary,
+            memory=pre_memory,
+            left_context_key=left_context_key,
+            left_context_val=left_context_val,
+        )
         state = self._pack_state(
             next_key, next_val, utterance.size(0), memory, state
         )
@@ -718,20 +771,22 @@ class EmformerLayer(nn.Module):
             layer_norm_utterance,
             layer_norm_right_context,
         ) = self._apply_pre_attention_layer_norm(utterance, right_context)
-        output_right_context_utterance, output_memory = \
-            self._apply_attention_forward(
-                layer_norm_utterance,
-                lengths,
-                layer_norm_right_context,
-                memory,
-                attention_mask,
-            )
-        output_utterance, output_right_context = \
-            self._apply_post_attention_ffn_layer_norm(
-                output_right_context_utterance,
-                utterance,
-                right_context
-            )
+        (
+            output_right_context_utterance,
+            output_memory,
+        ) = self._apply_attention_forward(
+            layer_norm_utterance,
+            lengths,
+            layer_norm_right_context,
+            memory,
+            attention_mask,
+        )
+        (
+            output_utterance,
+            output_right_context,
+        ) = self._apply_post_attention_ffn_layer_norm(
+            output_right_context_utterance, utterance, right_context
+        )
         return output_utterance, output_right_context, output_memory

     @torch.jit.export
@@ -745,63 +800,66 @@ class EmformerLayer(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], torch.Tensor]:
         """Forward pass for inference.

         1) Apply layer normalization on input utterance and right context
         before attention;
         2) Apply attention module with cached state, compute updated utterance,
         right context, and memory, and update state;
-        3) Apply feed forward module and layer normalization on output utterance
-        and right context.
+        3) Apply feed forward module and layer normalization on output
+        utterance and right context.

         B: batch size;
         D: embedding dimension;
         R: length of right_context;
         U: length of utterance;
         M: length of memory.

         Args:
             utterance (torch.Tensor):
                 Utterance frames, with shape (U, B, D).
             lengths (torch.Tensor):
                 With shape (B,) and i-th element representing
                 number of valid frames for i-th batch element in utterance.
             right_context (torch.Tensor):
                 Right context frames, with shape (R, B, D).
             memory (torch.Tensor):
                 Memory elements, with shape (M, B, D).
             state (List[torch.Tensor], optional):
                 List of tensors representing layer internal state generated in
                 preceding computation. (default=None)

         Returns:
             (Tensor, Tensor, List[torch.Tensor], Tensor):
                 - output utterance, with shape (U, B, D);
                 - output right_context, with shape (R, B, D);
                 - output memory, with shape (1, B, D) or (0, B, D).
                 - output state.
         """
         (
             layer_norm_utterance,
             layer_norm_right_context,
         ) = self._apply_pre_attention_layer_norm(utterance, right_context)
-        output_right_context_utterance, output_memory, output_state = \
-            self._apply_attention_infer(
-                layer_norm_utterance,
-                lengths,
-                layer_norm_right_context,
-                memory,
-                state
-            )
-        output_utterance, output_right_context = \
-            self._apply_post_attention_ffn_layer_norm(
-                output_right_context_utterance,
-                utterance,
-                right_context
-            )
+        (
+            output_right_context_utterance,
+            output_memory,
+            output_state,
+        ) = self._apply_attention_infer(
+            layer_norm_utterance,
+            lengths,
+            layer_norm_right_context,
+            memory,
+            state,
+        )
+        (
+            output_utterance,
+            output_right_context,
+        ) = self._apply_post_attention_ffn_layer_norm(
+            output_right_context_utterance, utterance, right_context
+        )
         return (
             output_utterance,
             output_right_context,
             output_memory,
-            output_state
+            output_state,
         )
@@ -895,7 +953,7 @@ class EmformerEncoder(nn.Module):
         self.max_memory_size = max_memory_size

     def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor:
-        """Hard copy each chunk's right context and concat them. """
+        """Hard copy each chunk's right context and concat them."""
         T = x.shape[0]
         num_segs = math.ceil(
             (T - self.right_context_length) / self.chunk_length
@@ -905,7 +963,8 @@ class EmformerEncoder(nn.Module):
             start = (seg_idx + 1) * self.chunk_length
             end = start + self.right_context_length
             right_context_blocks.append(x[start:end])
-        right_context_blocks.append(x[T - self.right_context_length:])
+        last_right_context_start_idx = T - self.right_context_length
+        right_context_blocks.append(x[last_right_context_start_idx:])
         return torch.cat(right_context_blocks)

     def _gen_attention_mask_col_widths(
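
A self-contained sketch of _gen_right_context with small, made-up lengths, showing why the result has num_chunks * right_context_length frames:

import math

import torch

chunk_length, right_context_length = 4, 2
T, B, D = 14, 1, 8  # 12 utterance frames followed by 2 trailing right-context frames
x = torch.randn(T, B, D)

num_segs = math.ceil((T - right_context_length) / chunk_length)
right_context_blocks = []
for seg_idx in range(num_segs - 1):
    start = (seg_idx + 1) * chunk_length
    right_context_blocks.append(x[start : start + right_context_length])
last_right_context_start_idx = T - right_context_length
right_context_blocks.append(x[last_right_context_start_idx:])
print(torch.cat(right_context_blocks).shape)  # torch.Size([6, 1, 8]): 3 chunks * 2 frames
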
@@ -981,8 +1040,9 @@ class EmformerEncoder(nn.Module):
             num_cols = 9
             # right context and utterance both attend to memory, right context,
             # utterance
-            right_context_utterance_cols_mask = \
-                [idx in [1, 4, 7] for idx in range(num_cols)]
+            right_context_utterance_cols_mask = [
+                idx in [1, 4, 7] for idx in range(num_cols)
+            ]
             # summary attends to right context, utterance
             summary_cols_mask = [idx in [4, 7] for idx in range(num_cols)]
             masks_to_concat = [right_context_mask, utterance_mask, summary_mask]
@@ -990,8 +1050,9 @@ class EmformerEncoder(nn.Module):
             num_cols = 6
             # right context and utterance both attend to right context and
             # utterance
-            right_context_utterance_cols_mask = \
-                [idx in [1, 4] for idx in range(num_cols)]
+            right_context_utterance_cols_mask = [
+                idx in [1, 4] for idx in range(num_cols)
+            ]
             summary_cols_mask = None
             masks_to_concat = [right_context_mask, utterance_mask]
@@ -1002,7 +1063,7 @@ class EmformerEncoder(nn.Module):
                 col_widths,
                 right_context_utterance_cols_mask,
                 self.right_context_length,
-                utterance.device
+                utterance.device,
             )
             right_context_mask.append(right_context_mask_block)
@@ -1053,13 +1114,13 @@ class EmformerEncoder(nn.Module):
           right_context at the end.
         """
         right_context = self._gen_right_context(x)
-        utterance = x[:x.size(0) - self.right_context_length]
+        utterance = x[: x.size(0) - self.right_context_length]
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
         attention_mask = self._gen_attention_mask(utterance)
         memory = (
-            self.init_memory_op(
-                utterance.permute(1, 2, 0)
-            ).permute(2, 0, 1)[:-1]
+            self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[
+                :-1
+            ]
             if self.use_memory
             else torch.empty(0).to(dtype=x.dtype, device=x.device)
         )
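
A sketch of the memory-bank computation in forward(), assuming init_memory_op is an average pool over chunks (toy sizes; the pooling module is an assumption, since its definition is not shown in this hunk):

import torch
import torch.nn as nn

chunk_length, U, B, D = 4, 12, 1, 8
utterance = torch.randn(U, B, D)

init_memory_op = nn.AvgPool1d(
    kernel_size=chunk_length, stride=chunk_length, ceil_mode=True
)
# (U, B, D) -> (B, D, U) -> pooled -> (num_chunks, B, D); drop the last chunk's vector.
memory = init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1]
print(memory.shape)  # torch.Size([2, 1, 8]): 3 chunks -> keep the first 2
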
@@ -1159,13 +1220,9 @@ class Emformer(EncoderInterface):
         self.subsampling_factor = subsampling_factor
         self.right_context_length = right_context_length
         if subsampling_factor != 4:
-            raise NotImplementedError(
-                "Support only 'subsampling_factor=4'."
-            )
+            raise NotImplementedError("Support only 'subsampling_factor=4'.")
         if chunk_length % 4 != 0:
-            raise NotImplementedError(
-                "chunk_length must be a mutiple of 4."
-            )
+            raise NotImplementedError("chunk_length must be a mutiple of 4.")
         if left_context_length != 0 and left_context_length % 4 != 0:
             raise NotImplementedError(
                 "left_context_length must be 0 or a mutiple of 4."
@@ -1289,8 +1346,9 @@ class Emformer(EncoderInterface):
         x_lens = ((x_lens - 1) // 2 - 1) // 2
         assert x.size(0) == x_lens.max().item()

-        output, output_lengths, output_states = \
-            self.encoder.infer(x, x_lens, states)  # (T, N, C)
+        output, output_lengths, output_states = self.encoder.infer(
+            x, x_lens, states
+        )  # (T, N, C)

         logits = self.encoder_output_layer(output)
         logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
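
The x_lens arithmetic above corresponds to two stride-2 convolutions in the subsampling frontend; a quick check with arbitrary lengths:

import torch

x_lens = torch.tensor([100, 17, 8])
out_lens = ((x_lens - 1) // 2 - 1) // 2
print(out_lens)  # tensor([24, 3, 1])
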