Mirror of https://github.com/k2-fsa/icefall.git
add option of using full attention for streaming model decoding (#975)
commit 2a5a75cb56
parent bcc5923ab9
```diff
@@ -107,6 +107,7 @@ Usage:
 import argparse
 import logging
+import math
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
```
```diff
@@ -138,6 +139,8 @@ from icefall.utils import (
     write_error_stats,
 )
 
+LOG_EPS = math.log(1e-10)
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
```
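Aside (not part of the diff): the features in these recipes are log-scale filterbanks, so `math.log(1e-10)` acts as a log-domain zero; frames padded with it look like near-silence rather than signal. A minimal check:

```python
import math

# log(1e-10) ~= -23.03: in the log-energy domain this is "almost no
# energy", so padded frames approximate silence.
LOG_EPS = math.log(1e-10)
print(LOG_EPS)  # -23.025850929940457
```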
```diff
@@ -288,7 +291,7 @@ def get_parser():
         "--decode-chunk-size",
         type=int,
         default=16,
-        help="The chunk size for decoding (in frames after subsampling)",
+        help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
     )
 
     parser.add_argument(
         "--left-context",
```
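A quick illustration (not from the commit): the new sentinel travels through argparse as an ordinary negative integer; it is the encoder that later interprets `-1` as "use full attention".

```python
import argparse

# Hypothetical stand-alone parser mirroring the option above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--decode-chunk-size",
    type=int,
    default=16,
    help="The chunk size for decoding (in frames after subsampling). "
    "Set -1 to use full attention.",
)

args = parser.parse_args(["--decode-chunk-size=-1"])
# argparse stores -1 like any other int; downstream code checks
# `chunk_size < 0` to switch to full attention.
assert args.decode_chunk_size == -1
```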
```diff
@@ -370,6 +373,14 @@ def decode_one_batch(
     feature_lens = supervisions["num_frames"].to(device)
 
     if params.simulate_streaming:
+        if params.decode_chunk_size > 0:
+            # except the case of using full attention
+            feature_lens += params.left_context
+            feature = torch.nn.functional.pad(
+                feature,
+                pad=(0, 0, 0, params.left_context),
+                value=LOG_EPS,
+            )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
```
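To unpack the padding above (a sketch with made-up shapes, not code from the repo): `torch.nn.functional.pad` reads the `pad` tuple from the last dimension backwards, so `(0, 0, 0, left_context)` leaves the feature axis untouched and appends `left_context` frames of `LOG_EPS` at the end of the time axis, with `feature_lens` bumped to match.

```python
import math

import torch

LOG_EPS = math.log(1e-10)

# Hypothetical batch: 2 utterances, 100 frames, 80-dim log-mel features.
feature = torch.randn(2, 100, 80)
feature_lens = torch.tensor([100, 73])
left_context = 64

# pad=(0, 0, 0, left_context) means: (left, right) padding of 0 for the
# last dim (features), then (top, bottom) = (0, left_context) for the
# second-to-last dim (time).
feature = torch.nn.functional.pad(
    feature, pad=(0, 0, 0, left_context), value=LOG_EPS
)
feature_lens = feature_lens + left_context

print(feature.shape)  # torch.Size([2, 164, 80])
print(feature_lens)   # tensor([164, 137])
```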
```diff
@@ -375,6 +375,11 @@ class Conformer(EncoderInterface):
 
         assert x.size(0) == lengths.max().item()
 
+        if chunk_size < 0:
+            # use full attention
+            chunk_size = x.size(0)
+            left_context = -1
+
         num_left_chunks = -1
         if left_context >= 0:
             assert left_context % chunk_size == 0
```
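Why does `chunk_size = x.size(0)` with `left_context = -1` give full attention? Here is a self-contained sketch of a chunk-causal attention mask (an illustration of the idea, not icefall's actual mask code): with a single chunk spanning all frames and unlimited left context, every position may attend to every other.

```python
import torch


def chunk_attention_mask(
    size: int, chunk_size: int, num_left_chunks: int = -1
) -> torch.Tensor:
    """mask[i, j] is True where position i may attend to position j."""
    mask = torch.zeros(size, size, dtype=torch.bool)
    for i in range(size):
        # Position i sees up to the end of its own chunk...
        end = (i // chunk_size + 1) * chunk_size
        # ...and back through `num_left_chunks` chunks (-1 = unlimited).
        if num_left_chunks < 0:
            start = 0
        else:
            start = max(0, (i // chunk_size - num_left_chunks) * chunk_size)
        mask[i, start:end] = True
    return mask


T = 8
# Streaming-style mask: chunks of 2 frames, one left chunk visible.
print(chunk_attention_mask(T, chunk_size=2, num_left_chunks=1))
# The degenerate case wired up above: one chunk covering everything and
# unlimited left context yields an all-True mask, i.e. full attention.
assert chunk_attention_mask(T, chunk_size=T, num_left_chunks=-1).all()
```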
```diff
@@ -295,7 +295,7 @@ def get_parser():
         "--decode-chunk-size",
         type=int,
         default=16,
-        help="The chunk size for decoding (in frames after subsampling)",
+        help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
     )
 
     parser.add_argument(
```
```diff
@@ -378,6 +378,8 @@ def decode_one_batch(
     feature_lens = supervisions["num_frames"].to(device)
 
     if params.simulate_streaming:
+        if params.decode_chunk_size > 0:
+            # except the case of using full attention
             feature_lens += params.left_context
             feature = torch.nn.functional.pad(
                 feature,
```
```diff
@@ -344,7 +344,7 @@ def get_parser():
         "--decode-chunk-size",
         type=int,
         default=16,
-        help="The chunk size for decoding (in frames after subsampling)",
+        help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
     )
 
     parser.add_argument(
```
```diff
@@ -508,6 +508,8 @@ def decode_one_batch(
     feature_lens = supervisions["num_frames"].to(device)
 
     if params.simulate_streaming:
+        if params.decode_chunk_size > 0:
+            # except the case of using full attention
             feature_lens += params.left_context
             feature = torch.nn.functional.pad(
                 feature,
```
```diff
@@ -326,14 +326,14 @@ def get_parser():
         "--decode-chunk-size",
         type=int,
         default=16,
-        help="The chunk size for decoding (in frames after subsampling)",
+        help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
     )
 
     parser.add_argument(
         "--left-context",
         type=int,
         default=64,
-        help="left context can be seen during decoding (in frames after subsampling)",  # noqa
+        help="""Left context can be seen during decoding (in frames after subsampling). """,  # noqa
     )
 
     parser.add_argument(
```
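The two defaults interact through the assertion visible in the conformer hunks: left context is consumed in whole chunks, so 64 frames of left context at a chunk size of 16 means 4 visible left chunks. The division below is an assumption about the elided encoder code, consistent with that assertion.

```python
chunk_size = 16    # --decode-chunk-size default
left_context = 64  # --left-context default

# The encoder asserts left_context % chunk_size == 0: left context is
# consumed in whole chunks. The chunk count is then a simple division
# (an assumption about the elided code, consistent with the assertion).
assert left_context % chunk_size == 0
num_left_chunks = left_context // chunk_size
print(num_left_chunks)  # 4
```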
```diff
@@ -409,6 +409,8 @@ def decode_one_batch(
     feature_lens = supervisions["num_frames"].to(device)
 
     if params.simulate_streaming:
+        if params.decode_chunk_size > 0:
+            # except the case of using full attention
             feature_lens += params.left_context
             feature = torch.nn.functional.pad(
                 feature,
```
```diff
@@ -291,7 +291,7 @@ def get_parser():
         "--decode-chunk-size",
         type=int,
         default=16,
-        help="The chunk size for decoding (in frames after subsampling)",
+        help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
     )
 
     parser.add_argument(
```
```diff
@@ -470,6 +470,8 @@ def decode_one_batch(
     feature_lens = supervisions["num_frames"].to(device)
 
     if params.simulate_streaming:
+        if params.decode_chunk_size > 0:
+            # except the case of using full attention
             feature_lens += params.left_context
             feature = torch.nn.functional.pad(
                 feature,
```
```diff
@@ -358,6 +358,11 @@ class Conformer(Transformer):
 
         assert x.size(0) == lengths.max().item()
 
+        if chunk_size < 0:
+            # use full attention
+            chunk_size = x.size(0)
+            left_context = -1
+
         num_left_chunks = -1
         if left_context >= 0:
             assert left_context % chunk_size == 0
```