add option of using full attention for streaming model decoding (#975)

Zengwei Yao 2023-03-30 14:30:13 +08:00 committed by GitHub
parent bcc5923ab9
commit 2a5a75cb56
7 changed files with 59 additions and 30 deletions
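At the user level, the change amounts to a sentinel value: passing --decode-chunk-size -1 together with simulated streaming decoding makes the encoder fall back to full attention, while positive values keep the existing chunk-wise behaviour. A minimal, self-contained sketch of how the flag is declared and interpreted (the tiny parser below is illustrative, not the full get_parser() from the decode scripts):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--decode-chunk-size",
    type=int,
    default=16,
    help="The chunk size for decoding (in frames after subsampling). "
    "Set -1 to use full attention.",
)
args = parser.parse_args(["--decode-chunk-size", "-1"])

if args.decode_chunk_size > 0:
    print("chunk-wise (simulated streaming) attention")
else:
    print("full attention over the whole utterance")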

View File

@@ -107,6 +107,7 @@ Usage:
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@@ -138,6 +139,8 @@ from icefall.utils import (
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
@@ -288,7 +291,7 @@ def get_parser():
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
)
parser.add_argument(
"--left-context",
@@ -370,6 +373,14 @@ def decode_one_batch(
feature_lens = supervisions["num_frames"].to(device)
if params.simulate_streaming:
if params.decode_chunk_size > 0:
# this padding is not needed when using full attention
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
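The decode scripts all receive the same guard: the extra left-context padding is applied only when chunk-wise attention is in use, and skipped entirely for full attention. A self-contained sketch of that logic, factored into a hypothetical helper (maybe_pad_for_left_context is not an icefall function; it only illustrates the guarded padding above):

import math

import torch

LOG_EPS = math.log(1e-10)  # padding value for log-scale features, as in the scripts

def maybe_pad_for_left_context(
    feature: torch.Tensor,  # (N, T, C) features
    feature_lens: torch.Tensor,  # (N,) number of valid frames
    decode_chunk_size: int,
    left_context: int,
):
    """Pad the time axis by left_context frames for chunk-wise decoding.

    With full attention (decode_chunk_size <= 0) the inputs are returned
    unchanged, matching the new branch in decode_one_batch.
    """
    if decode_chunk_size > 0:
        feature_lens = feature_lens + left_context
        # pad=(0, 0, 0, left_context) pads the feature dim by (0, 0) and the
        # time dim by (0, left_context), i.e. appends frames on the right.
        feature = torch.nn.functional.pad(
            feature, pad=(0, 0, 0, left_context), value=LOG_EPS
        )
    return feature, feature_lens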

View File

@@ -375,6 +375,11 @@ class Conformer(EncoderInterface):
assert x.size(0) == lengths.max().item()
if chunk_size < 0:
# use full attention
chunk_size = x.size(0)
left_context = -1
num_left_chunks = -1
if left_context >= 0:
assert left_context % chunk_size == 0
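Inside the Conformer, a negative chunk_size is rewritten so that a single chunk spans the whole utterance and the number of left chunks is unlimited; the chunk-wise attention mask then degenerates into full attention, and setting left_context = -1 also makes the left_context % chunk_size == 0 assertion a no-op. A small illustration of that equivalence using a generic chunk mask (the helper is written for this example; it is not icefall's mask code):

import torch

def chunk_attention_mask(size: int, chunk_size: int, num_left_chunks: int) -> torch.Tensor:
    """mask[i, j] is True if frame i may attend to frame j."""
    mask = torch.zeros(size, size, dtype=torch.bool)
    for i in range(size):
        cur_chunk = i // chunk_size
        if num_left_chunks < 0:
            start = 0  # unlimited left context
        else:
            start = max((cur_chunk - num_left_chunks) * chunk_size, 0)
        end = min((cur_chunk + 1) * chunk_size, size)
        mask[i, start:end] = True
    return mask

T = 8
# chunk_size == T with num_left_chunks == -1 is what the encoder falls back to
# when chunk_size < 0: every frame may attend to every other frame.
assert bool(chunk_attention_mask(T, chunk_size=T, num_left_chunks=-1).all())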

View File

@@ -295,7 +295,7 @@ def get_parser():
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
)
parser.add_argument(
@@ -378,12 +378,14 @@ def decode_one_batch(
feature_lens = supervisions["num_frames"].to(device)
if params.simulate_streaming:
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
if params.decode_chunk_size > 0:
# this padding is not needed when using full attention
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,

View File

@@ -344,7 +344,7 @@ def get_parser():
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
)
parser.add_argument(
@@ -508,12 +508,14 @@ def decode_one_batch(
feature_lens = supervisions["num_frames"].to(device)
if params.simulate_streaming:
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
if params.decode_chunk_size > 0:
# this padding is not needed when using full attention
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,

View File

@@ -326,14 +326,14 @@ def get_parser():
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)", # noqa
help="""Left context can be seen during decoding (in frames after subsampling). """, # noqa
)
parser.add_argument(
@@ -409,12 +409,14 @@ def decode_one_batch(
feature_lens = supervisions["num_frames"].to(device)
if params.simulate_streaming:
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
if params.decode_chunk_size > 0:
# this padding is not needed when using full attention
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,

View File

@@ -291,7 +291,7 @@ def get_parser():
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
)
parser.add_argument(
@@ -470,12 +470,14 @@ def decode_one_batch(
feature_lens = supervisions["num_frames"].to(device)
if params.simulate_streaming:
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
if params.decode_chunk_size > 0:
# this padding is not needed when using full attention
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,

View File

@@ -358,6 +358,11 @@ class Conformer(Transformer):
assert x.size(0) == lengths.max().item()
if chunk_size < 0:
# use full attention
chunk_size = x.size(0)
left_context = -1
num_left_chunks = -1
if left_context >= 0:
assert left_context % chunk_size == 0