Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-08 09:32:20 +00:00
add option of using full attention for streaming model decoding (#975)

Parent: bcc5923ab9
Commit: 2a5a75cb56
@@ -107,6 +107,7 @@ Usage:
 import argparse
 import logging
+import math
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -138,6 +139,8 @@ from icefall.utils import (
     write_error_stats,
 )

+LOG_EPS = math.log(1e-10)
+

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -288,7 +291,7 @@ def get_parser():
         "--decode-chunk-size",
         type=int,
         default=16,
-        help="The chunk size for decoding (in frames after subsampling)",
+        help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
     )
     parser.add_argument(
         "--left-context",
@@ -370,6 +373,14 @@ def decode_one_batch(
     feature_lens = supervisions["num_frames"].to(device)

     if params.simulate_streaming:
+        if params.decode_chunk_size > 0:
+            # except the case of using full attention
+            feature_lens += params.left_context
+            feature = torch.nn.functional.pad(
+                feature,
+                pad=(0, 0, 0, params.left_context),
+                value=LOG_EPS,
+            )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
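For context, the following is a minimal standalone sketch of the decoding-side behaviour introduced above; it is my own illustration, not code from this repository, and the helper name maybe_pad_for_streaming is hypothetical. Left-context padding is applied only when a positive chunk size is given; with -1 the features pass through unchanged and the encoder runs with full attention.

import math

import torch

LOG_EPS = math.log(1e-10)


def maybe_pad_for_streaming(
    feature: torch.Tensor,  # (N, T, C) features
    feature_lens: torch.Tensor,  # (N,) valid frame counts per utterance
    decode_chunk_size: int,
    left_context: int,
):
    # Pad only for chunk-wise streaming decoding, mirroring the branch in the diff.
    if decode_chunk_size > 0:
        feature_lens = feature_lens + left_context
        feature = torch.nn.functional.pad(
            feature,
            pad=(0, 0, 0, left_context),  # pad the time axis on the right
            value=LOG_EPS,
        )
    # decode_chunk_size == -1: full attention, no extra padding is needed.
    return feature, feature_lens


if __name__ == "__main__":
    x = torch.zeros(2, 100, 80)
    x_lens = torch.tensor([100, 90])
    y, y_lens = maybe_pad_for_streaming(x, x_lens, decode_chunk_size=-1, left_context=64)
    assert y.shape == x.shape and torch.equal(y_lens, x_lens)
    y, y_lens = maybe_pad_for_streaming(x, x_lens, decode_chunk_size=16, left_context=64)
    assert y.shape == (2, 164, 80) and y_lens.tolist() == [164, 154]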
@@ -375,6 +375,11 @@ class Conformer(EncoderInterface):

         assert x.size(0) == lengths.max().item()

+        if chunk_size < 0:
+            # use full attention
+            chunk_size = x.size(0)
+            left_context = -1
+
         num_left_chunks = -1
         if left_context >= 0:
             assert left_context % chunk_size == 0
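To show what the chunk_size < 0 branch above achieves, here is a small self-contained sketch, again my own illustration rather than the Conformer code: it builds a chunk-wise attention mask, and with a negative chunk_size the whole sequence becomes a single chunk with no left-context limit, so every frame may attend to every other frame.

import torch


def make_chunk_mask(seq_len: int, chunk_size: int, left_context: int) -> torch.Tensor:
    """Return a (seq_len, seq_len) bool mask; True means the query may attend to the key."""
    if chunk_size < 0:
        # Full attention: one chunk covering everything, no left-context restriction.
        chunk_size = seq_len
        left_context = -1

    num_left_chunks = -1
    if left_context >= 0:
        assert left_context % chunk_size == 0
        num_left_chunks = left_context // chunk_size

    chunk_idx = torch.arange(seq_len) // chunk_size
    # A frame may attend to its own chunk and to earlier chunks.
    allowed = chunk_idx.unsqueeze(1) >= chunk_idx.unsqueeze(0)
    if num_left_chunks >= 0:
        # Restrict how many past chunks are visible.
        allowed &= (chunk_idx.unsqueeze(1) - chunk_idx.unsqueeze(0)) <= num_left_chunks
    return allowed


if __name__ == "__main__":
    # With chunk_size=-1 every position sees every other position.
    print(make_chunk_mask(6, chunk_size=-1, left_context=64).all().item())  # True
    # With chunk_size=2 and left_context=4, only the current and two previous chunks are visible.
    print(make_chunk_mask(6, chunk_size=2, left_context=4))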
@@ -295,7 +295,7 @@ def get_parser():
         "--decode-chunk-size",
         type=int,
         default=16,
-        help="The chunk size for decoding (in frames after subsampling)",
+        help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
     )

     parser.add_argument(
@@ -378,12 +378,14 @@ def decode_one_batch(
     feature_lens = supervisions["num_frames"].to(device)

     if params.simulate_streaming:
-        feature_lens += params.left_context
-        feature = torch.nn.functional.pad(
-            feature,
-            pad=(0, 0, 0, params.left_context),
-            value=LOG_EPS,
-        )
+        if params.decode_chunk_size > 0:
+            # except the case of using full attention
+            feature_lens += params.left_context
+            feature = torch.nn.functional.pad(
+                feature,
+                pad=(0, 0, 0, params.left_context),
+                value=LOG_EPS,
+            )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
@@ -344,7 +344,7 @@ def get_parser():
         "--decode-chunk-size",
         type=int,
         default=16,
-        help="The chunk size for decoding (in frames after subsampling)",
+        help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
     )

     parser.add_argument(
@@ -508,12 +508,14 @@ def decode_one_batch(
     feature_lens = supervisions["num_frames"].to(device)

     if params.simulate_streaming:
-        feature_lens += params.left_context
-        feature = torch.nn.functional.pad(
-            feature,
-            pad=(0, 0, 0, params.left_context),
-            value=LOG_EPS,
-        )
+        if params.decode_chunk_size > 0:
+            # except the case of using full attention
+            feature_lens += params.left_context
+            feature = torch.nn.functional.pad(
+                feature,
+                pad=(0, 0, 0, params.left_context),
+                value=LOG_EPS,
+            )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
@@ -326,14 +326,14 @@ def get_parser():
         "--decode-chunk-size",
         type=int,
         default=16,
-        help="The chunk size for decoding (in frames after subsampling)",
+        help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
     )

     parser.add_argument(
         "--left-context",
         type=int,
         default=64,
-        help="left context can be seen during decoding (in frames after subsampling)",  # noqa
+        help="""Left context can be seen during decoding (in frames after subsampling). """,  # noqa
     )

     parser.add_argument(
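For reference, the two options touched in this hunk can be reproduced in a small standalone parser. The sketch below only restates what the diff shows (defaults 16 and 64, with -1 selecting full attention); it is not the recipe's actual get_parser().

import argparse


def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--decode-chunk-size",
        type=int,
        default=16,
        help="The chunk size for decoding (in frames after subsampling). "
        "Set -1 to use full attention.",
    )
    parser.add_argument(
        "--left-context",
        type=int,
        default=64,
        help="Left context can be seen during decoding (in frames after subsampling).",
    )
    return parser


if __name__ == "__main__":
    args = get_parser().parse_args(["--decode-chunk-size", "-1"])
    print(args.decode_chunk_size)  # -1 -> full attention during simulated streaming decoding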
@@ -409,12 +409,14 @@ def decode_one_batch(
     feature_lens = supervisions["num_frames"].to(device)

     if params.simulate_streaming:
-        feature_lens += params.left_context
-        feature = torch.nn.functional.pad(
-            feature,
-            pad=(0, 0, 0, params.left_context),
-            value=LOG_EPS,
-        )
+        if params.decode_chunk_size > 0:
+            # except the case of using full attention
+            feature_lens += params.left_context
+            feature = torch.nn.functional.pad(
+                feature,
+                pad=(0, 0, 0, params.left_context),
+                value=LOG_EPS,
+            )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
@@ -291,7 +291,7 @@ def get_parser():
         "--decode-chunk-size",
         type=int,
         default=16,
-        help="The chunk size for decoding (in frames after subsampling)",
+        help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
     )

     parser.add_argument(
@@ -470,12 +470,14 @@ def decode_one_batch(
     feature_lens = supervisions["num_frames"].to(device)

     if params.simulate_streaming:
-        feature_lens += params.left_context
-        feature = torch.nn.functional.pad(
-            feature,
-            pad=(0, 0, 0, params.left_context),
-            value=LOG_EPS,
-        )
+        if params.decode_chunk_size > 0:
+            # except the case of using full attention
+            feature_lens += params.left_context
+            feature = torch.nn.functional.pad(
+                feature,
+                pad=(0, 0, 0, params.left_context),
+                value=LOG_EPS,
+            )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
@@ -358,6 +358,11 @@ class Conformer(Transformer):

         assert x.size(0) == lengths.max().item()

+        if chunk_size < 0:
+            # use full attention
+            chunk_size = x.size(0)
+            left_context = -1
+
         num_left_chunks = -1
         if left_context >= 0:
             assert left_context % chunk_size == 0