mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Partial work
This commit is contained in:
parent
bcc9971ebe
commit
e51a2c9170
@ -36,6 +36,7 @@ from scaling import (
|
||||
ScheduledFloat,
|
||||
FloatLike,
|
||||
limit_param_value,
|
||||
clip_grad,
|
||||
convert_num_channels,
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
@ -49,14 +50,15 @@ class Subformer(EncoderInterface):
|
||||
as downsampling_factor if they are single ints or one-element tuples. The length of
|
||||
downsampling_factor defines the number of stacks.
|
||||
|
||||
output_downsampling_factor (int): how much to downsample at the output. Note:
|
||||
we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
|
||||
You should probably leave this at 2.
|
||||
downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
|
||||
Note: this is in addition to the downsampling factor of 2 that is applied in
|
||||
the frontend (self.encoder_embed).
|
||||
|
||||
structure (str): determines the structure of the module, S is encoder stack,
|
||||
open-parenthesis is downsampling operation, close-parenthesis is a corresponding
|
||||
upsampling operation (but not all parentheses have to be closed if you want
|
||||
the whole stack to downsample.)
|
||||
encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
|
||||
encoder stack.
|
||||
encoder stack (i.e. one per "S" in structure).
|
||||
downsampling_factor (Tuple[int]): downsampling factor for each downsampling
|
||||
operation (each open-parenthesis).
|
||||
num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
|
||||
query_head_dim (int or Tuple[int]): dimension of query and key per attention
|
||||
head: per stack, if a tuple..
|
||||
@ -80,13 +82,15 @@ class Subformer(EncoderInterface):
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
encoder_dim: Union[int, Tuple[int]] = (384, 512, 384),
|
||||
encoder_chunk_size: Union[int, Tuple[int]] = 128,
|
||||
num_encoder_layers: Union[int, Tuple[int]] = 4,
|
||||
query_head_dim: Union[int, Tuple[int]] = 24,
|
||||
value_head_dim: Union[int, Tuple[int]] = 12,
|
||||
num_heads: Union[int, Tuple[int]] = 8,
|
||||
feedforward_dim: Union[int, Tuple[int]] = 1536,
|
||||
structure: str = "S(S)S",
|
||||
encoder_dim: Tuple[int, ...] = (384, 512, 384),
|
||||
downsampling_factor: Tuple[int, ...] = (2,),
|
||||
encoder_chunk_sizes: Tuple[Tuple[int, ...]] = (128,),
|
||||
num_encoder_layers: Union[int, Tuple[int, ...]] = (4,),
|
||||
query_head_dim: Tuple[int, ...] = (24,),
|
||||
value_head_dim: Tuple[int, ...] = (12,),
|
||||
num_heads: Tuple[int, ...] = (8,),
|
||||
feedforward_dim: Tuple[int, ...] = (1536,),
|
||||
memory_dim: int = -1,
|
||||
pos_dim: int = 4,
|
||||
dropout: Optional[FloatLike] = None, # see code below for default
|
||||
@ -99,15 +103,20 @@ class Subformer(EncoderInterface):
|
||||
dropout = ScheduledFloat((0.0, 0.3),
|
||||
(20000.0, 0.1))
|
||||
|
||||
num_encoders = len([s for s in structure if s == 'S'])
|
||||
num_downsamplers = len([s for s in structure if s == '('])
|
||||
# when we upsample, we use the same downsampling object that we
|
||||
# downsampled with, but we also need a BypassModule at that point.
|
||||
num_bypass = len([s for s in structure if s == ')'])
|
||||
|
||||
def _to_tuple(x):
|
||||
""" Converts a single int or a 1-tuple of an int to a tuple with the same length
|
||||
as encoder_dim"""
|
||||
if isinstance(x, int):
|
||||
x = (x,)
|
||||
as num_encoders"""
|
||||
assert isinstance(x, tuple)
|
||||
if len(x) == 1:
|
||||
x = x * len(encoder_dim)
|
||||
x = x * num_encoders
|
||||
else:
|
||||
assert len(x) == len(encoder_dim) and isinstance(x[0], int)
|
||||
assert len(x) == num_encoders
|
||||
return x
|
||||
|
||||
self.encoder_dim = encoder_dim
|
||||
@ -120,18 +129,18 @@ class Subformer(EncoderInterface):
|
||||
self.causal = causal
|
||||
|
||||
|
||||
if len(downsampling_factor) == 1:
|
||||
downsampling_factor = downsampling_factor * num_downsamplers
|
||||
assert len(downsampling_factor) == num_downsamplers
|
||||
|
||||
# each one will be SubformerEncoder or DownsampledSubformerEncoder
|
||||
encoders = []
|
||||
downsamplers = []
|
||||
bypass = []
|
||||
|
||||
|
||||
num_encoders = len(encoder_dim)
|
||||
assert num_encoders % 2 == 1
|
||||
downsampling_factor = [ 1 ]
|
||||
while len(downsampling_factor) < num_encoders:
|
||||
downsampling_factor = [ 1 ] + [ d * 2 for d in downsampling_factor ] + [ 1 ]
|
||||
|
||||
for i in range(num_encoders):
|
||||
|
||||
for s in structure:
|
||||
if s == 'S':
|
||||
i = len(encoders)
|
||||
encoder_layer = SubformerEncoderLayer(
|
||||
embed_dim=encoder_dim[i],
|
||||
pos_dim=pos_dim,
|
||||
@ -144,8 +153,6 @@ class Subformer(EncoderInterface):
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
# For the segment of the warmup period, we let the Conv2dSubsampling
|
||||
# layer learn something. Then we start to warm up the other encoders.
|
||||
encoder = SubformerEncoder(
|
||||
encoder_layer,
|
||||
num_encoder_layers[i],
|
||||
@ -155,9 +162,24 @@ class Subformer(EncoderInterface):
|
||||
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
|
||||
final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
|
||||
)
|
||||
|
||||
encoders.append(encoder)
|
||||
|
||||
pass
|
||||
elif s =='(':
|
||||
pass
|
||||
else:
|
||||
assert s == ')'
|
||||
|
||||
|
||||
num_encoders = len(encoder_dim)
|
||||
assert num_encoders % 2 == 1
|
||||
downsampling_factor = [ 1 ]
|
||||
while len(downsampling_factor) < num_encoders:
|
||||
downsampling_factor = [ 1 ] + [ d * 2 for d in downsampling_factor ] + [ 1 ]
|
||||
|
||||
for i in range(num_encoders):
|
||||
|
||||
|
||||
mid = len(encoders) // 2
|
||||
encoder = DownsampledSubformerEncoder(
|
||||
[ encoders[mid] ],
|
||||
@ -567,13 +589,13 @@ class SubformerEncoder(nn.Module):
|
||||
dropout: float,
|
||||
warmup_begin: float,
|
||||
warmup_end: float,
|
||||
chunk_size: int = 256,
|
||||
chunk_sizes: Tuple[int, ...] = (128, 2048),
|
||||
initial_layerdrop_rate: float = 0.5,
|
||||
final_layerdrop_rate: float = 0.05,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_sizes = chunk_sizes
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
||||
@ -668,13 +690,17 @@ class SubformerEncoder(nn.Module):
|
||||
chunk_indexes: a list of indexes into chunk_sizes, one per layer.
|
||||
"""
|
||||
seq_len = src.shape[0]
|
||||
if seq_len <= self.chunk_size or seq_len % self.chunk_size != 0:
|
||||
return [ seq_len ], [ 0 ] * len(self.layers)
|
||||
else:
|
||||
num_layers = len(self.layers)
|
||||
chunk_indexes = [0, 1] * (num_layers + 1 // 2)
|
||||
return [ self.chunk_size, seq_len ], chunk_indexes[:num_layers]
|
||||
chunk_indexes = []
|
||||
chunk_sizes = []
|
||||
for i, chunk_size in enumerate(self.chunk_sizes):
|
||||
chunk_sizes.append(chunk_size if seq_len % chunk_size == 0
|
||||
else seq_len)
|
||||
|
||||
num_chunk_sizes = len(self.chunk_sizes)
|
||||
for i in range(self.num_layers):
|
||||
chunk_indexes.append(i % num_chunk_sizes)
|
||||
|
||||
return chunk_sizes, chunk_indexes
|
||||
|
||||
def _to_chunk_size(self, src: Tensor, chunk_size: int) -> Tensor:
|
||||
"""
|
||||
@ -809,6 +835,7 @@ class LearnedDownsamplingModule(nn.Module):
|
||||
def __init__(self,
|
||||
embed_dim: int,
|
||||
downsampling_factor: int):
|
||||
assert downsampling_factor > 1
|
||||
|
||||
super().__init__()
|
||||
|
||||
@ -864,9 +891,8 @@ class LearnedDownsamplingModule(nn.Module):
|
||||
d = self.downsampling_factor
|
||||
seq_len_reduced = (seq_len + d - 1) // d
|
||||
|
||||
|
||||
weights_discarded = weights[:, seq_len_reduced:2*seq_len_reduced]
|
||||
missing = weights_discarded.shape[1] - seq_len_reduced
|
||||
missing = seq_len_reduced - weights_discarded.shape[1]
|
||||
if missing != 0:
|
||||
weights_discarded = torch.cat((weights_discarded,
|
||||
torch.zeros(batch_size, missing,
|
||||
@ -986,6 +1012,12 @@ class LearnedDownsamplingModule(nn.Module):
|
||||
assert len(attn_offset.shape) == 3 # (1, seq_len, seq_len) or (batch_size, seq_len, seq_len)
|
||||
attn_offset = attn_offset.expand(batch_size, seq_len, seq_len)
|
||||
|
||||
if torch.is_autocast_enabled():
|
||||
# it's possible to get large gradients at this point; clip these at
|
||||
# this point to reduce the extent to which it has to reduce the
|
||||
# grad_scale.
|
||||
weights = clip_grad(weights, 5000.0)
|
||||
|
||||
attn_offset = attn_offset.gather(dim=1, index=indexes.unsqueeze(-1).expand(
|
||||
batch_size, seq_len_reduced, seq_len))
|
||||
attn_offset = attn_offset.gather(dim=2, index=indexes.unsqueeze(1).expand(
|
||||
@ -1849,7 +1881,7 @@ def _test_zipformer_main(causal: bool = False):
|
||||
memory_dim=memory_dim,
|
||||
)
|
||||
batch_size = 5
|
||||
seq_len = 20
|
||||
seq_len = 128
|
||||
# Just make sure the forward pass runs.
|
||||
f = c(
|
||||
torch.randn(seq_len, batch_size, 64),
|
||||
|
Loading…
x
Reference in New Issue
Block a user