mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Change how chunk-size is specified
This commit is contained in:
parent
e7e7560bba
commit
329175c897
@ -116,7 +116,7 @@ from beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from train import add_model_arguments, get_params, get_transducer_model, get_chunk_info
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
|
||||
@ -84,8 +84,6 @@ class Transducer(nn.Module):
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
chunk_size: int = -1,
|
||||
left_context_chunks: int = -1,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -106,9 +104,6 @@ class Transducer(nn.Module):
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
chunk_size, left_context_chunks:
|
||||
For chunkwise causal training; will be passed to the zipformer encoder.
|
||||
chunk_size is specified in frames at 50Hz, i.e. after 2x downsampling.
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
|
||||
@ -124,8 +119,8 @@ class Transducer(nn.Module):
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||
|
||||
encoder_out, x_lens = self.encoder(x, x_lens, chunk_size=chunk_size,
|
||||
left_context_chunks=left_context_chunks)
|
||||
encoder_out, x_lens = self.encoder(x, x_lens)
|
||||
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
|
||||
@ -1117,7 +1117,7 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module):
|
||||
# to make the convolution causal.
|
||||
left_pad = self.kernel_size // 2
|
||||
|
||||
if chunk_size < 0:
|
||||
if chunk_size < 0 or chunk_size > seq_len:
|
||||
chunk_size = seq_len
|
||||
right_pad = -seq_len % chunk_size
|
||||
|
||||
|
||||
@ -226,20 +226,29 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--chunk-size",
|
||||
type=str,
|
||||
default="-1",
|
||||
help=" Embedding dimension in encoder stacks: a single int or comma-separated list."
|
||||
"--causal",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="If True, use causal version of model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--chunk-left-context-frames",
|
||||
"--chunk-size",
|
||||
type=str,
|
||||
default="16,32,64,-1",
|
||||
help="Chunk sizes will be chosen randomly from this list during training. "
|
||||
" Must be just -1 if --causal=False"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--left-context-frames",
|
||||
type=str,
|
||||
default="64,128,256,-1",
|
||||
help="Left-contexts for chunkwise training, measured in frames (positive values must be "
|
||||
"multiples of all positive elements of chunk-size). If --chunk-size is specified, "
|
||||
"chunk left-context frames will be chosen randomly from this list."
|
||||
help="Maximum left-contexts for causal training, measured in frames which will "
|
||||
"be converted to a number of chunks. If splitting into chunks, "
|
||||
"chunk left-context frames will be chosen randomly from this list; else not relevant."
|
||||
)
|
||||
|
||||
|
||||
@ -544,7 +553,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
|
||||
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
||||
warmup_batches=4000.0,
|
||||
causal=(params.chunk_size != "-1"),
|
||||
causal=params.causal,
|
||||
chunk_size=to_int_tuple(params.chunk_size),
|
||||
left_context_frames=to_int_tuple(params.left_context_frames),
|
||||
)
|
||||
return encoder
|
||||
|
||||
@ -705,25 +716,6 @@ def save_checkpoint(
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
|
||||
|
||||
def get_chunk_info(params: AttributeDict) -> Tuple[int, int]:
    """
    Randomly pick the chunk size and the number of left-context chunks.

    Args:
      params:
        Must provide `chunk_size` and `chunk_left_context_frames`, each a
        string containing a comma-separated list of integers.
        `chunk_left_context_frames` is only read when the chunk size that
        was drawn is not -1.
    Returns:
      Return a tuple (chunk_size, left_context_chunks).  If the chosen
      chunk size is -1, (-1, -1) is returned.  Otherwise left_context_chunks
      is the chosen left-context (in frames) floor-divided by the chunk
      size, which is -1 when the chosen left-context is -1.
    """
    chunk_sizes = [int(s) for s in params.chunk_size.split(',')]
    chunk_size = random.choice(chunk_sizes)
    if chunk_size == -1:
        left_context_chunks = -1
    else:
        chunk_left_context_frames = [
            int(s) for s in params.chunk_left_context_frames.split(',')
        ]
        left_context_frames = random.choice(chunk_left_context_frames)
        if left_context_frames != -1:
            # A positive left-context must cover a whole number of chunks.
            assert left_context_frames % chunk_size == 0, "Invalid --chunk-left-context-frames value"
        # Note: in Python, -1 // n == -1 for n > 0
        left_context_chunks = left_context_frames // chunk_size
    return chunk_size, left_context_chunks
|
||||
|
||||
|
||||
def compute_loss(
|
||||
params: AttributeDict,
|
||||
@ -770,8 +762,6 @@ def compute_loss(
|
||||
y = sp.encode(texts, out_type=int)
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
|
||||
chunk_size, left_context_chunks = get_chunk_info(params)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss = model(
|
||||
x=feature,
|
||||
@ -780,8 +770,6 @@ def compute_loss(
|
||||
prune_range=params.prune_range,
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
chunk_size=chunk_size,
|
||||
left_context_chunks=left_context_chunks,
|
||||
)
|
||||
|
||||
s = params.simple_loss_scale
|
||||
|
||||
@ -102,6 +102,12 @@ class Zipformer(EncoderInterface):
|
||||
slightly slower and use more memory. Enables use of the chunk_size and
|
||||
left_context_chunk options in forward(), which simulates streaming
|
||||
decoding.
|
||||
chunk_size: (list of int): only set this to other than [-1] if causal;
|
||||
the chunk size will be randomly chosen from this list. -1 means no chunking.
|
||||
left_context_frames: (list of int): determines the number of left-
|
||||
context chunks for causal training; will be rounded to a number of
|
||||
chunks. Must not be less than cnn_module_kernel (after factoring in
|
||||
rounding and downsampling); an error will be thrown if this is violated.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
@ -122,6 +128,8 @@ class Zipformer(EncoderInterface):
|
||||
dropout: FloatLike = None, # see code below for default
|
||||
warmup_batches: float = 4000.0,
|
||||
causal: bool = False,
|
||||
chunk_size: Tuple[int] = [-1],
|
||||
left_context_frames: Tuple[int] = [-1],
|
||||
) -> None:
|
||||
super(Zipformer, self).__init__()
|
||||
|
||||
@ -162,6 +170,10 @@ class Zipformer(EncoderInterface):
|
||||
feedforward_dim = _to_tuple(feedforward_dim)
|
||||
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
|
||||
|
||||
self.causal = causal
|
||||
self.chunk_size = chunk_size
|
||||
self.left_context_frames = left_context_frames
|
||||
|
||||
for u,d in zip(encoder_unmasked_dim, encoder_dim):
|
||||
assert u <= d
|
||||
|
||||
@ -319,10 +331,24 @@ class Zipformer(EncoderInterface):
|
||||
return feature_masks
|
||||
|
||||
|
||||
def get_chunk_info(self) -> Tuple[int, int]:
    """
    Draw the chunking configuration from this module's settings.

    Reads `self.causal`, `self.chunk_size` and `self.left_context_frames`.

    Returns:
      Return a tuple (chunk_size, left_context_chunks):
        - (-1, -1) if `self.causal` is False, or if the chunk size drawn
          from `self.chunk_size` is -1;
        - otherwise the drawn chunk size, together with a value drawn from
          `self.left_context_frames` floor-divided by that chunk size.
    """
    if not self.causal:
        return -1, -1

    chosen_chunk = random.choice(self.chunk_size)
    if chosen_chunk == -1:
        # The left-context list is deliberately not sampled in this case.
        return chosen_chunk, -1

    chosen_left_frames = random.choice(self.left_context_frames)
    # Floor division maps a left-context of -1 to -1 chunks, because
    # -1 // n == -1 for any n > 0.
    return chosen_chunk, chosen_left_frames // chosen_chunk
|
||||
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor,
|
||||
chunk_size: int = -1,
|
||||
left_context_chunks: int = -1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
@ -362,6 +388,8 @@ class Zipformer(EncoderInterface):
|
||||
outputs = []
|
||||
feature_masks = self.get_feature_masks(x)
|
||||
|
||||
chunk_size, left_context_chunks = self.get_chunk_info()
|
||||
|
||||
attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)
|
||||
|
||||
for i, module in enumerate(self.encoders):
|
||||
@ -2257,6 +2285,8 @@ def _test_zipformer_main(causal: bool = False):
|
||||
c = Zipformer(
|
||||
num_features=feature_dim, encoder_dim=(64,96), encoder_unmasked_dim=(48,64), num_heads=(4,4),
|
||||
causal=causal,
|
||||
chunk_size=(4,) if causal else (-1,),
|
||||
left_context_frames=(64,)
|
||||
)
|
||||
batch_size = 5
|
||||
seq_len = 20
|
||||
@ -2264,7 +2294,6 @@ def _test_zipformer_main(causal: bool = False):
|
||||
f = c(
|
||||
torch.randn(batch_size, seq_len, feature_dim),
|
||||
torch.full((batch_size,), seq_len, dtype=torch.int64),
|
||||
chunk_size=4 if causal else -1,
|
||||
)
|
||||
f[0].sum().backward()
|
||||
c.eval()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user