add option of using full attention for streaming model decoding (#975)

This commit is contained in:
Zengwei Yao 2023-03-30 14:30:13 +08:00 committed by GitHub
parent bcc5923ab9
commit 2a5a75cb56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 59 additions and 30 deletions

View File

@@ -107,6 +107,7 @@ Usage:
import argparse import argparse
import logging import logging
import math
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
@@ -138,6 +139,8 @@ from icefall.utils import (
write_error_stats, write_error_stats,
) )
LOG_EPS = math.log(1e-10)
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@@ -288,7 +291,7 @@ def get_parser():
"--decode-chunk-size", "--decode-chunk-size",
type=int, type=int,
default=16, default=16,
help="The chunk size for decoding (in frames after subsampling)", help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
) )
parser.add_argument( parser.add_argument(
"--left-context", "--left-context",
@@ -370,6 +373,14 @@ def decode_one_batch(
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
if params.simulate_streaming: if params.simulate_streaming:
if params.decode_chunk_size > 0:
# except the case of using full attention
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,

View File

@@ -375,6 +375,11 @@ class Conformer(EncoderInterface):
assert x.size(0) == lengths.max().item() assert x.size(0) == lengths.max().item()
if chunk_size < 0:
# use full attention
chunk_size = x.size(0)
left_context = -1
num_left_chunks = -1 num_left_chunks = -1
if left_context >= 0: if left_context >= 0:
assert left_context % chunk_size == 0 assert left_context % chunk_size == 0

View File

@@ -295,7 +295,7 @@ def get_parser():
"--decode-chunk-size", "--decode-chunk-size",
type=int, type=int,
default=16, default=16,
help="The chunk size for decoding (in frames after subsampling)", help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
) )
parser.add_argument( parser.add_argument(
@@ -378,12 +378,14 @@ def decode_one_batch(
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
if params.simulate_streaming: if params.simulate_streaming:
feature_lens += params.left_context if params.decode_chunk_size > 0:
feature = torch.nn.functional.pad( # except the case of using full attention
feature, feature_lens += params.left_context
pad=(0, 0, 0, params.left_context), feature = torch.nn.functional.pad(
value=LOG_EPS, feature,
) pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,

View File

@@ -344,7 +344,7 @@ def get_parser():
"--decode-chunk-size", "--decode-chunk-size",
type=int, type=int,
default=16, default=16,
help="The chunk size for decoding (in frames after subsampling)", help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
) )
parser.add_argument( parser.add_argument(
@@ -508,12 +508,14 @@ def decode_one_batch(
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
if params.simulate_streaming: if params.simulate_streaming:
feature_lens += params.left_context if params.decode_chunk_size > 0:
feature = torch.nn.functional.pad( # except the case of using full attention
feature, feature_lens += params.left_context
pad=(0, 0, 0, params.left_context), feature = torch.nn.functional.pad(
value=LOG_EPS, feature,
) pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,

View File

@@ -326,14 +326,14 @@ def get_parser():
"--decode-chunk-size", "--decode-chunk-size",
type=int, type=int,
default=16, default=16,
help="The chunk size for decoding (in frames after subsampling)", help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
) )
parser.add_argument( parser.add_argument(
"--left-context", "--left-context",
type=int, type=int,
default=64, default=64,
help="left context can be seen during decoding (in frames after subsampling)", # noqa help="""Left context can be seen during decoding (in frames after subsampling). """, # noqa
) )
parser.add_argument( parser.add_argument(
@@ -409,12 +409,14 @@ def decode_one_batch(
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
if params.simulate_streaming: if params.simulate_streaming:
feature_lens += params.left_context if params.decode_chunk_size > 0:
feature = torch.nn.functional.pad( # except the case of using full attention
feature, feature_lens += params.left_context
pad=(0, 0, 0, params.left_context), feature = torch.nn.functional.pad(
value=LOG_EPS, feature,
) pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,

View File

@@ -291,7 +291,7 @@ def get_parser():
"--decode-chunk-size", "--decode-chunk-size",
type=int, type=int,
default=16, default=16,
help="The chunk size for decoding (in frames after subsampling)", help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.",
) )
parser.add_argument( parser.add_argument(
@@ -470,12 +470,14 @@ def decode_one_batch(
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
if params.simulate_streaming: if params.simulate_streaming:
feature_lens += params.left_context if params.decode_chunk_size > 0:
feature = torch.nn.functional.pad( # except the case of using full attention
feature, feature_lens += params.left_context
pad=(0, 0, 0, params.left_context), feature = torch.nn.functional.pad(
value=LOG_EPS, feature,
) pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,

View File

@@ -358,6 +358,11 @@ class Conformer(Transformer):
assert x.size(0) == lengths.max().item() assert x.size(0) == lengths.max().item()
if chunk_size < 0:
# use full attention
chunk_size = x.size(0)
left_context = -1
num_left_chunks = -1 num_left_chunks = -1
if left_context >= 0: if left_context >= 0:
assert left_context % chunk_size == 0 assert left_context % chunk_size == 0