Mirror of https://github.com/k2-fsa/icefall.git
output from middle layer
parent 8d73423a29
commit 9ee57959ec
@@ -18,7 +18,7 @@
 import copy
 import math
 import warnings
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from encoder_interface import EncoderInterface
@@ -61,6 +61,7 @@ class Conformer(EncoderInterface):
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
         cnn_module_kernel: int = 31,
+        extra_output_layer: Optional[int] = None,  # Default: no distillation loss.
     ) -> None:
         super(Conformer, self).__init__()
 
@@ -86,7 +87,18 @@ class Conformer(EncoderInterface):
             layer_dropout,
             cnn_module_kernel,
         )
-        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
+        # The last layer is always needed.
+        self.output_layers = [num_encoder_layers - 1]
+        if extra_output_layer is not None:
+            assert (
+                extra_output_layer >= 0
+                and extra_output_layer < num_encoder_layers - 1
+            )
+            self.output_layers.insert(0, extra_output_layer)
+        self.encoder = ConformerEncoder(
+            encoder_layer, num_encoder_layers, output_layers=self.output_layers
+        )
 
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
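To make the constructor logic concrete, here is a minimal sketch of how output_layers is assembled; the numeric values are hypothetical and do not appear in this commit:

    # Sketch only: assembling output_layers as the constructor does.
    num_encoder_layers = 12  # hypothetical value
    extra_output_layer = 5   # hypothetical value

    output_layers = [num_encoder_layers - 1]  # [11]: the last layer is always needed
    if extra_output_layer is not None:
        assert 0 <= extra_output_layer < num_encoder_layers - 1
        output_layers.insert(0, extra_output_layer)  # [5, 11]: the middle tap comes first
    # With extra_output_layer = None, output_layers stays [11] and the
    # encoder behaves exactly as before this change.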
@@ -119,12 +131,18 @@ class Conformer(EncoderInterface):
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)
 
-        x = self.encoder(
+        layers_result = self.encoder(
             x, pos_emb, src_key_padding_mask=mask, warmup=warmup
         )  # (T, N, C)
+        # layers_result[0] is the output from a middle layer for distillation.
+        # layers_result[-1] is the output from the final layer for RNN-T loss.
+        x = layers_result[-1]
 
         x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
 
+        if len(self.output_layers) > 1:
+            assert len(self.output_layers) == len(layers_result)
+            return x, lengths, layers_result[0]
         return x, lengths
 
 
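A minimal caller sketch of the new forward() contract, assuming extra_output_layer was set at construction time; the teacher target and the MSE distillation loss below are illustrative assumptions, not part of this commit:

    # Sketch only: consuming the three-tuple returned when a middle layer
    # is tapped. `model`, `teacher_embedding`, `rnnt_loss`, and
    # `distill_scale` are hypothetical names.
    import torch.nn.functional as F

    encoder_out, encoder_out_lens, middle_out = model(feats, feat_lens)
    # encoder_out has been permuted to (N, T, C), but middle_out is
    # returned as-is from the encoder, i.e. still (T, N, C).
    middle_out = middle_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
    distill_loss = F.mse_loss(middle_out, teacher_embedding)
    loss = rnnt_loss + distill_scale * distill_loss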
@@ -276,12 +294,18 @@ class ConformerEncoder(nn.Module):
         >>> out = conformer_encoder(src, pos_emb)
         """
 
-    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
+    def __init__(
+        self,
+        encoder_layer: nn.Module,
+        num_layers: int,
+        output_layers: List[int],
+    ) -> None:
         super().__init__()
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
         self.num_layers = num_layers
+        self.output_layers = output_layers
 
     def forward(
         self,
@@ -290,7 +314,7 @@ class ConformerEncoder(nn.Module):
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
-    ) -> Tensor:
+    ) -> List[Tensor]:
         r"""Pass the input through the encoder layers in turn.
 
         Args:
@@ -309,6 +333,7 @@ class ConformerEncoder(nn.Module):
         """
         output = src
 
+        layers_result = []
         for i, mod in enumerate(self.layers):
             output = mod(
                 output,
@@ -317,8 +342,10 @@ class ConformerEncoder(nn.Module):
                 src_key_padding_mask=src_key_padding_mask,
                 warmup=warmup,
             )
+            if i in self.output_layers:
+                layers_result.append(output)
 
-        return output
+        return layers_result
 
 
 class RelPositionalEncoding(torch.nn.Module):
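The collection pattern introduced in ConformerEncoder.forward() can be exercised in isolation. Below is a self-contained toy version in which plain nn.Linear layers stand in for conformer blocks; nothing here is from the commit itself:

    import torch
    import torch.nn as nn

    class TapEncoder(nn.Module):
        # Toy stand-in for ConformerEncoder: runs layers in sequence and
        # records the output of every layer listed in output_layers.
        def __init__(self, num_layers: int, dim: int, output_layers: list):
            super().__init__()
            self.layers = nn.ModuleList(
                nn.Linear(dim, dim) for _ in range(num_layers)
            )
            self.output_layers = output_layers

        def forward(self, x: torch.Tensor) -> list:
            layers_result = []
            for i, mod in enumerate(self.layers):
                x = mod(x)
                if i in self.output_layers:
                    layers_result.append(x)
            return layers_result

    enc = TapEncoder(num_layers=12, dim=8, output_layers=[5, 11])
    outs = enc(torch.randn(4, 8))
    assert len(outs) == 2  # one middle-layer tap plus the final layer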