mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Partial work
This commit is contained in:
parent
bcc9971ebe
commit
e51a2c9170
@ -36,6 +36,7 @@ from scaling import (
|
|||||||
ScheduledFloat,
|
ScheduledFloat,
|
||||||
FloatLike,
|
FloatLike,
|
||||||
limit_param_value,
|
limit_param_value,
|
||||||
|
clip_grad,
|
||||||
convert_num_channels,
|
convert_num_channels,
|
||||||
)
|
)
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
@ -49,14 +50,15 @@ class Subformer(EncoderInterface):
|
|||||||
as downsampling_factor if they are single ints or one-element tuples. The length of
|
as downsampling_factor if they are single ints or one-element tuples. The length of
|
||||||
downsampling_factor defines the number of stacks.
|
downsampling_factor defines the number of stacks.
|
||||||
|
|
||||||
output_downsampling_factor (int): how much to downsample at the output. Note:
|
|
||||||
we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
|
structure (str): determines the structure of the module, S is encoder stack,
|
||||||
You should probably leave this at 2.
|
open-parenthesis is downsampling operation, close-parenthesis is a corresponding
|
||||||
downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
|
upsampling operation (but not all parentheses have to be closed if you want
|
||||||
Note: this is in addition to the downsampling factor of 2 that is applied in
|
the whole stack to downsample.)
|
||||||
the frontend (self.encoder_embed).
|
|
||||||
encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
|
encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
|
||||||
encoder stack.
|
encoder stack (i.e. one per "S" in structure).
|
||||||
|
downsampling_factor (Tuple[int]): downsampling factor for each downsampling
|
||||||
|
operation (each open-parenthesis).
|
||||||
num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
|
num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
|
||||||
query_head_dim (int or Tuple[int]): dimension of query and key per attention
|
query_head_dim (int or Tuple[int]): dimension of query and key per attention
|
||||||
head: per stack, if a tuple..
|
head: per stack, if a tuple..
|
||||||
@ -80,13 +82,15 @@ class Subformer(EncoderInterface):
|
|||||||
"""
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
encoder_dim: Union[int, Tuple[int]] = (384, 512, 384),
|
structure: str = "S(S)S",
|
||||||
encoder_chunk_size: Union[int, Tuple[int]] = 128,
|
encoder_dim: Tuple[int, ...] = (384, 512, 384),
|
||||||
num_encoder_layers: Union[int, Tuple[int]] = 4,
|
downsampling_factor: Tuple[int, ...] = (2,),
|
||||||
query_head_dim: Union[int, Tuple[int]] = 24,
|
encoder_chunk_sizes: Tuple[Tuple[int, ...]] = (128,),
|
||||||
value_head_dim: Union[int, Tuple[int]] = 12,
|
num_encoder_layers: Union[int, Tuple[int, ...]] = (4,),
|
||||||
num_heads: Union[int, Tuple[int]] = 8,
|
query_head_dim: Tuple[int, ...] = (24,),
|
||||||
feedforward_dim: Union[int, Tuple[int]] = 1536,
|
value_head_dim: Tuple[int, ...] = (12,),
|
||||||
|
num_heads: Tuple[int, ...] = (8,),
|
||||||
|
feedforward_dim: Tuple[int, ...] = (1536,),
|
||||||
memory_dim: int = -1,
|
memory_dim: int = -1,
|
||||||
pos_dim: int = 4,
|
pos_dim: int = 4,
|
||||||
dropout: Optional[FloatLike] = None, # see code below for default
|
dropout: Optional[FloatLike] = None, # see code below for default
|
||||||
@ -99,15 +103,20 @@ class Subformer(EncoderInterface):
|
|||||||
dropout = ScheduledFloat((0.0, 0.3),
|
dropout = ScheduledFloat((0.0, 0.3),
|
||||||
(20000.0, 0.1))
|
(20000.0, 0.1))
|
||||||
|
|
||||||
|
num_encoders = len([s for s in structure if s == 'S'])
|
||||||
|
num_downsamplers = len([s for s in structure if s == '('])
|
||||||
|
# when we upsample, we use the same downsampling object that we
|
||||||
|
# downsampled with, but we also need a BypassModule at that point.
|
||||||
|
num_bypass = len([s for s in structure if s == ')'])
|
||||||
|
|
||||||
def _to_tuple(x):
|
def _to_tuple(x):
|
||||||
""" Converts a single int or a 1-tuple of an int to a tuple with the same length
|
""" Converts a single int or a 1-tuple of an int to a tuple with the same length
|
||||||
as encoder_dim"""
|
as num_encoders"""
|
||||||
if isinstance(x, int):
|
assert isinstance(x, tuple)
|
||||||
x = (x,)
|
|
||||||
if len(x) == 1:
|
if len(x) == 1:
|
||||||
x = x * len(encoder_dim)
|
x = x * num_encoders
|
||||||
else:
|
else:
|
||||||
assert len(x) == len(encoder_dim) and isinstance(x[0], int)
|
assert len(x) == num_encoders
|
||||||
return x
|
return x
|
||||||
|
|
||||||
self.encoder_dim = encoder_dim
|
self.encoder_dim = encoder_dim
|
||||||
@ -120,8 +129,46 @@ class Subformer(EncoderInterface):
|
|||||||
self.causal = causal
|
self.causal = causal
|
||||||
|
|
||||||
|
|
||||||
|
if len(downsampling_factor) == 1:
|
||||||
|
downsampling_factor = downsampling_factor * num_downsamplers
|
||||||
|
assert len(downsampling_factor) == num_downsamplers
|
||||||
|
|
||||||
# each one will be SubformerEncoder or DownsampledSubformerEncoder
|
# each one will be SubformerEncoder or DownsampledSubformerEncoder
|
||||||
encoders = []
|
encoders = []
|
||||||
|
downsamplers = []
|
||||||
|
bypass = []
|
||||||
|
|
||||||
|
for s in structure:
|
||||||
|
if s == 'S':
|
||||||
|
i = len(encoders)
|
||||||
|
encoder_layer = SubformerEncoderLayer(
|
||||||
|
embed_dim=encoder_dim[i],
|
||||||
|
pos_dim=pos_dim,
|
||||||
|
num_heads=num_heads[i],
|
||||||
|
query_head_dim=query_head_dim[i],
|
||||||
|
value_head_dim=value_head_dim[i],
|
||||||
|
feedforward_dim=feedforward_dim[i],
|
||||||
|
memory_dim=memory_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
causal=causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder = SubformerEncoder(
|
||||||
|
encoder_layer,
|
||||||
|
num_encoder_layers[i],
|
||||||
|
dropout=dropout,
|
||||||
|
chunk_size=encoder_chunk_size[i],
|
||||||
|
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
|
||||||
|
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
|
||||||
|
final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
|
||||||
|
)
|
||||||
|
encoders.append(encoder)
|
||||||
|
|
||||||
|
pass
|
||||||
|
elif s =='(':
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
assert s == ')'
|
||||||
|
|
||||||
|
|
||||||
num_encoders = len(encoder_dim)
|
num_encoders = len(encoder_dim)
|
||||||
@ -132,31 +179,6 @@ class Subformer(EncoderInterface):
|
|||||||
|
|
||||||
for i in range(num_encoders):
|
for i in range(num_encoders):
|
||||||
|
|
||||||
encoder_layer = SubformerEncoderLayer(
|
|
||||||
embed_dim=encoder_dim[i],
|
|
||||||
pos_dim=pos_dim,
|
|
||||||
num_heads=num_heads[i],
|
|
||||||
query_head_dim=query_head_dim[i],
|
|
||||||
value_head_dim=value_head_dim[i],
|
|
||||||
feedforward_dim=feedforward_dim[i],
|
|
||||||
memory_dim=memory_dim,
|
|
||||||
dropout=dropout,
|
|
||||||
causal=causal,
|
|
||||||
)
|
|
||||||
|
|
||||||
# For the segment of the warmup period, we let the Conv2dSubsampling
|
|
||||||
# layer learn something. Then we start to warm up the other encoders.
|
|
||||||
encoder = SubformerEncoder(
|
|
||||||
encoder_layer,
|
|
||||||
num_encoder_layers[i],
|
|
||||||
dropout=dropout,
|
|
||||||
chunk_size=encoder_chunk_size[i],
|
|
||||||
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
|
|
||||||
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
|
|
||||||
final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
|
|
||||||
)
|
|
||||||
|
|
||||||
encoders.append(encoder)
|
|
||||||
|
|
||||||
mid = len(encoders) // 2
|
mid = len(encoders) // 2
|
||||||
encoder = DownsampledSubformerEncoder(
|
encoder = DownsampledSubformerEncoder(
|
||||||
@ -567,13 +589,13 @@ class SubformerEncoder(nn.Module):
|
|||||||
dropout: float,
|
dropout: float,
|
||||||
warmup_begin: float,
|
warmup_begin: float,
|
||||||
warmup_end: float,
|
warmup_end: float,
|
||||||
chunk_size: int = 256,
|
chunk_sizes: Tuple[int, ...] = (128, 2048),
|
||||||
initial_layerdrop_rate: float = 0.5,
|
initial_layerdrop_rate: float = 0.5,
|
||||||
final_layerdrop_rate: float = 0.05,
|
final_layerdrop_rate: float = 0.05,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.chunk_size = chunk_size
|
self.chunk_sizes = chunk_sizes
|
||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
||||||
@ -668,13 +690,17 @@ class SubformerEncoder(nn.Module):
|
|||||||
chunk_indexes: a list of indexes into chunk_sizes, one per layer.
|
chunk_indexes: a list of indexes into chunk_sizes, one per layer.
|
||||||
"""
|
"""
|
||||||
seq_len = src.shape[0]
|
seq_len = src.shape[0]
|
||||||
if seq_len <= self.chunk_size or seq_len % self.chunk_size != 0:
|
chunk_indexes = []
|
||||||
return [ seq_len ], [ 0 ] * len(self.layers)
|
chunk_sizes = []
|
||||||
else:
|
for i, chunk_size in enumerate(self.chunk_sizes):
|
||||||
num_layers = len(self.layers)
|
chunk_sizes.append(chunk_size if seq_len % chunk_size == 0
|
||||||
chunk_indexes = [0, 1] * (num_layers + 1 // 2)
|
else seq_len)
|
||||||
return [ self.chunk_size, seq_len ], chunk_indexes[:num_layers]
|
|
||||||
|
|
||||||
|
num_chunk_sizes = len(self.chunk_sizes)
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
chunk_indexes.append(i % num_chunk_sizes)
|
||||||
|
|
||||||
|
return chunk_sizes, chunk_indexes
|
||||||
|
|
||||||
def _to_chunk_size(self, src: Tensor, chunk_size: int) -> Tensor:
|
def _to_chunk_size(self, src: Tensor, chunk_size: int) -> Tensor:
|
||||||
"""
|
"""
|
||||||
@ -809,6 +835,7 @@ class LearnedDownsamplingModule(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
embed_dim: int,
|
embed_dim: int,
|
||||||
downsampling_factor: int):
|
downsampling_factor: int):
|
||||||
|
assert downsampling_factor > 1
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -864,9 +891,8 @@ class LearnedDownsamplingModule(nn.Module):
|
|||||||
d = self.downsampling_factor
|
d = self.downsampling_factor
|
||||||
seq_len_reduced = (seq_len + d - 1) // d
|
seq_len_reduced = (seq_len + d - 1) // d
|
||||||
|
|
||||||
|
|
||||||
weights_discarded = weights[:, seq_len_reduced:2*seq_len_reduced]
|
weights_discarded = weights[:, seq_len_reduced:2*seq_len_reduced]
|
||||||
missing = weights_discarded.shape[1] - seq_len_reduced
|
missing = seq_len_reduced - weights_discarded.shape[1]
|
||||||
if missing != 0:
|
if missing != 0:
|
||||||
weights_discarded = torch.cat((weights_discarded,
|
weights_discarded = torch.cat((weights_discarded,
|
||||||
torch.zeros(batch_size, missing,
|
torch.zeros(batch_size, missing,
|
||||||
@ -986,6 +1012,12 @@ class LearnedDownsamplingModule(nn.Module):
|
|||||||
assert len(attn_offset.shape) == 3 # (1, seq_len, seq_len) or (batch_size, seq_len, seq_len)
|
assert len(attn_offset.shape) == 3 # (1, seq_len, seq_len) or (batch_size, seq_len, seq_len)
|
||||||
attn_offset = attn_offset.expand(batch_size, seq_len, seq_len)
|
attn_offset = attn_offset.expand(batch_size, seq_len, seq_len)
|
||||||
|
|
||||||
|
if torch.is_autocast_enabled():
|
||||||
|
# it's possible to get large gradients at this point; clip these at
|
||||||
|
# this point to reduce the extent to which it has to reduce the
|
||||||
|
# grad_scale.
|
||||||
|
weights = clip_grad(weights, 5000.0)
|
||||||
|
|
||||||
attn_offset = attn_offset.gather(dim=1, index=indexes.unsqueeze(-1).expand(
|
attn_offset = attn_offset.gather(dim=1, index=indexes.unsqueeze(-1).expand(
|
||||||
batch_size, seq_len_reduced, seq_len))
|
batch_size, seq_len_reduced, seq_len))
|
||||||
attn_offset = attn_offset.gather(dim=2, index=indexes.unsqueeze(1).expand(
|
attn_offset = attn_offset.gather(dim=2, index=indexes.unsqueeze(1).expand(
|
||||||
@ -1849,7 +1881,7 @@ def _test_zipformer_main(causal: bool = False):
|
|||||||
memory_dim=memory_dim,
|
memory_dim=memory_dim,
|
||||||
)
|
)
|
||||||
batch_size = 5
|
batch_size = 5
|
||||||
seq_len = 20
|
seq_len = 128
|
||||||
# Just make sure the forward pass runs.
|
# Just make sure the forward pass runs.
|
||||||
f = c(
|
f = c(
|
||||||
torch.randn(seq_len, batch_size, 64),
|
torch.randn(seq_len, batch_size, 64),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user