mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Change how chunk-size is specified
This commit is contained in:
parent
e7e7560bba
commit
329175c897
@ -116,7 +116,7 @@ from beam_search import (
|
|||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model, get_chunk_info
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
|||||||
@ -84,8 +84,6 @@ class Transducer(nn.Module):
|
|||||||
prune_range: int = 5,
|
prune_range: int = 5,
|
||||||
am_scale: float = 0.0,
|
am_scale: float = 0.0,
|
||||||
lm_scale: float = 0.0,
|
lm_scale: float = 0.0,
|
||||||
chunk_size: int = -1,
|
|
||||||
left_context_chunks: int = -1,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -106,9 +104,6 @@ class Transducer(nn.Module):
|
|||||||
lm_scale:
|
lm_scale:
|
||||||
The scale to smooth the loss with lm (output of predictor network)
|
The scale to smooth the loss with lm (output of predictor network)
|
||||||
part
|
part
|
||||||
chunk_size, left_context_chunks:
|
|
||||||
For chunkwise causal training; will be passed to the zipformer encoder.
|
|
||||||
chunk_size is specified in frames at 50Hz, i.e. after 2x downsampling.
|
|
||||||
Returns:
|
Returns:
|
||||||
Return the transducer loss.
|
Return the transducer loss.
|
||||||
|
|
||||||
@ -124,8 +119,8 @@ class Transducer(nn.Module):
|
|||||||
|
|
||||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||||
|
|
||||||
encoder_out, x_lens = self.encoder(x, x_lens, chunk_size=chunk_size,
|
encoder_out, x_lens = self.encoder(x, x_lens)
|
||||||
left_context_chunks=left_context_chunks)
|
|
||||||
assert torch.all(x_lens > 0)
|
assert torch.all(x_lens > 0)
|
||||||
|
|
||||||
# Now for the decoder, i.e., the prediction network
|
# Now for the decoder, i.e., the prediction network
|
||||||
|
|||||||
@ -1117,7 +1117,7 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module):
|
|||||||
# to make the convolution causal.
|
# to make the convolution causal.
|
||||||
left_pad = self.kernel_size // 2
|
left_pad = self.kernel_size // 2
|
||||||
|
|
||||||
if chunk_size < 0:
|
if chunk_size < 0 or chunk_size > seq_len:
|
||||||
chunk_size = seq_len
|
chunk_size = seq_len
|
||||||
right_pad = -seq_len % chunk_size
|
right_pad = -seq_len % chunk_size
|
||||||
|
|
||||||
|
|||||||
@ -226,20 +226,29 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--chunk-size",
|
"--causal",
|
||||||
type=str,
|
type=str2bool,
|
||||||
default="-1",
|
default=True,
|
||||||
help=" Embedding dimension in encoder stacks: a single int or comma-separated list."
|
help="If True, use causal version of model.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--chunk-left-context-frames",
|
"--chunk-size",
|
||||||
|
type=str,
|
||||||
|
default="16,32,64,-1",
|
||||||
|
help="Chunk sizes will be chosen randomly from this list during training. "
|
||||||
|
" Must be just -1 if --causal=False"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--left-context-frames",
|
||||||
type=str,
|
type=str,
|
||||||
default="64,128,256,-1",
|
default="64,128,256,-1",
|
||||||
help="Left-contexts for chunkwise training, measured in frames (positive values must be "
|
help="Maximum left-contexts for causal training, measured in frames which will "
|
||||||
"multiples of all positive elements of chunk-size). If --chunk-size is specified, "
|
"be converted to a number of chunks. If splitting into chunks, "
|
||||||
"chunk left-context frames will be chosen randomly from this list."
|
"chunk left-context frames will be chosen randomly from this list; else not relevant."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -544,7 +553,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
|
cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
|
||||||
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
||||||
warmup_batches=4000.0,
|
warmup_batches=4000.0,
|
||||||
causal=(params.chunk_size != "-1"),
|
causal=params.causal,
|
||||||
|
chunk_size=to_int_tuple(params.chunk_size),
|
||||||
|
left_context_frames=to_int_tuple(params.left_context_frames),
|
||||||
)
|
)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
@ -705,25 +716,6 @@ def save_checkpoint(
|
|||||||
copyfile(src=filename, dst=best_valid_filename)
|
copyfile(src=filename, dst=best_valid_filename)
|
||||||
|
|
||||||
|
|
||||||
def get_chunk_info(params: AttributeDict) -> Tuple[int, int]:
|
|
||||||
"""
|
|
||||||
Returns chunk_size and left_context_chunks.
|
|
||||||
"""
|
|
||||||
chunk_sizes = list(map(int, params.chunk_size.split(',')))
|
|
||||||
n = len(chunk_sizes)
|
|
||||||
chunk_size = random.choice(chunk_sizes)
|
|
||||||
if chunk_size == -1:
|
|
||||||
left_context_chunks = -1
|
|
||||||
else:
|
|
||||||
chunk_left_context_frames = list(map(int, params.chunk_left_context_frames.split(',')))
|
|
||||||
m = len(chunk_left_context_frames)
|
|
||||||
left_context_frames = random.choice(chunk_left_context_frames)
|
|
||||||
if left_context_frames != -1:
|
|
||||||
assert left_context_frames % chunk_size == 0, "Invalid --chunk-left-context-frames value"
|
|
||||||
# Note: in Python, -1 // n == -1 for n > 0
|
|
||||||
left_context_chunks = left_context_frames // chunk_size
|
|
||||||
return chunk_size, left_context_chunks
|
|
||||||
|
|
||||||
|
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
@ -770,8 +762,6 @@ def compute_loss(
|
|||||||
y = sp.encode(texts, out_type=int)
|
y = sp.encode(texts, out_type=int)
|
||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
chunk_size, left_context_chunks = get_chunk_info(params)
|
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss = model(
|
simple_loss, pruned_loss = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
@ -780,8 +770,6 @@ def compute_loss(
|
|||||||
prune_range=params.prune_range,
|
prune_range=params.prune_range,
|
||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
chunk_size=chunk_size,
|
|
||||||
left_context_chunks=left_context_chunks,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
s = params.simple_loss_scale
|
s = params.simple_loss_scale
|
||||||
|
|||||||
@ -102,6 +102,12 @@ class Zipformer(EncoderInterface):
|
|||||||
slightly slower and use more memory. Enables use of the chunk_size and
|
slightly slower and use more memory. Enables use of the chunk_size and
|
||||||
left_context_chunk options in forward(), which simulates streaming
|
left_context_chunk options in forward(), which simulates streaming
|
||||||
decoding.
|
decoding.
|
||||||
|
chunk_size: (list of int): only set this to other than [-1] if causal;
|
||||||
|
the chunk size will be randomly chosen from this list. -1 means no chunking.
|
||||||
|
left_context_frames: (list of int): determines the number of left-
|
||||||
|
context chunks for causal training; will be rounded to a number of
|
||||||
|
chunks. Must not be less than cnn_module_kernel (after factoring in
|
||||||
|
rounding and downsampling); an error will be thrown if this is violated.
|
||||||
"""
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -122,6 +128,8 @@ class Zipformer(EncoderInterface):
|
|||||||
dropout: FloatLike = None, # see code below for default
|
dropout: FloatLike = None, # see code below for default
|
||||||
warmup_batches: float = 4000.0,
|
warmup_batches: float = 4000.0,
|
||||||
causal: bool = False,
|
causal: bool = False,
|
||||||
|
chunk_size: Tuple[int] = [-1],
|
||||||
|
left_context_frames: Tuple[int] = [-1],
|
||||||
) -> None:
|
) -> None:
|
||||||
super(Zipformer, self).__init__()
|
super(Zipformer, self).__init__()
|
||||||
|
|
||||||
@ -162,6 +170,10 @@ class Zipformer(EncoderInterface):
|
|||||||
feedforward_dim = _to_tuple(feedforward_dim)
|
feedforward_dim = _to_tuple(feedforward_dim)
|
||||||
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
|
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
|
||||||
|
|
||||||
|
self.causal = causal
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self.left_context_frames = left_context_frames
|
||||||
|
|
||||||
for u,d in zip(encoder_unmasked_dim, encoder_dim):
|
for u,d in zip(encoder_unmasked_dim, encoder_dim):
|
||||||
assert u <= d
|
assert u <= d
|
||||||
|
|
||||||
@ -319,10 +331,24 @@ class Zipformer(EncoderInterface):
|
|||||||
return feature_masks
|
return feature_masks
|
||||||
|
|
||||||
|
|
||||||
|
def get_chunk_info(self) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Returns chunk_size and left_context_chunks.
|
||||||
|
"""
|
||||||
|
if not self.causal:
|
||||||
|
return -1, -1
|
||||||
|
chunk_size = random.choice(self.chunk_size)
|
||||||
|
if chunk_size == -1:
|
||||||
|
left_context_chunks = -1
|
||||||
|
else:
|
||||||
|
left_context_frames = random.choice(self.left_context_frames)
|
||||||
|
# Note: in Python, -1 // n == -1 for n > 0
|
||||||
|
left_context_chunks = left_context_frames // chunk_size
|
||||||
|
return chunk_size, left_context_chunks
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor, x_lens: torch.Tensor,
|
self, x: torch.Tensor, x_lens: torch.Tensor,
|
||||||
chunk_size: int = -1,
|
|
||||||
left_context_chunks: int = -1,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -362,6 +388,8 @@ class Zipformer(EncoderInterface):
|
|||||||
outputs = []
|
outputs = []
|
||||||
feature_masks = self.get_feature_masks(x)
|
feature_masks = self.get_feature_masks(x)
|
||||||
|
|
||||||
|
chunk_size, left_context_chunks = self.get_chunk_info()
|
||||||
|
|
||||||
attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)
|
attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)
|
||||||
|
|
||||||
for i, module in enumerate(self.encoders):
|
for i, module in enumerate(self.encoders):
|
||||||
@ -2257,6 +2285,8 @@ def _test_zipformer_main(causal: bool = False):
|
|||||||
c = Zipformer(
|
c = Zipformer(
|
||||||
num_features=feature_dim, encoder_dim=(64,96), encoder_unmasked_dim=(48,64), num_heads=(4,4),
|
num_features=feature_dim, encoder_dim=(64,96), encoder_unmasked_dim=(48,64), num_heads=(4,4),
|
||||||
causal=causal,
|
causal=causal,
|
||||||
|
chunk_size=(4,) if causal else (-1,),
|
||||||
|
left_context_frames=(64,)
|
||||||
)
|
)
|
||||||
batch_size = 5
|
batch_size = 5
|
||||||
seq_len = 20
|
seq_len = 20
|
||||||
@ -2264,7 +2294,6 @@ def _test_zipformer_main(causal: bool = False):
|
|||||||
f = c(
|
f = c(
|
||||||
torch.randn(batch_size, seq_len, feature_dim),
|
torch.randn(batch_size, seq_len, feature_dim),
|
||||||
torch.full((batch_size,), seq_len, dtype=torch.int64),
|
torch.full((batch_size,), seq_len, dtype=torch.int64),
|
||||||
chunk_size=4 if causal else -1,
|
|
||||||
)
|
)
|
||||||
f[0].sum().backward()
|
f[0].sum().backward()
|
||||||
c.eval()
|
c.eval()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user