Minor fixes

This commit is contained in:
pkufool 2022-05-26 10:00:18 +08:00
parent 7cc697c03a
commit e923b1b336
2 changed files with 16 additions and 5 deletions

View File

@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import copy
import math
import warnings
@@ -234,7 +235,7 @@ class Conformer(EncoderInterface):
if not simulate_streaming:
assert (
decode_states is not None
states is not None
), "Require cache when sending data in streaming mode"
assert (
@@ -423,7 +424,7 @@ class ConformerEncoderLayer(nn.Module):
# src: [chunk_size, N, F] e.g. [8, 41, 512]
key = torch.cat([states[0, ...], src], dim=0)
val = key
states[0, ...] = key[-left_context, ...]
states[0, ...] = key[-left_context:, ...]
else:
assert left_context == 0
@@ -441,14 +442,15 @@ class ConformerEncoderLayer(nn.Module):
src = src + self.dropout(src_att)
# convolution module
residual = src
if not self.training and states is not None:
src = torch.cat([states[1, ...], src], dim=0)
states[1, ...] = src[-left_context, ...]
states[1, ...] = src[-left_context:, ...]
conv = self.conv_module(src)
conv = conv[-src.size(0) :, :, :] # noqa: E203
conv = conv[-residual.size(0) :, :, :] # noqa: E203
src = src + self.dropout(conv)
src = residual + self.dropout(conv)
# feed forward module
src = src + self.dropout(self.feed_forward(src))

View File

@@ -70,6 +70,7 @@ Usage:
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@@ -101,6 +102,7 @@ from icefall.utils import (
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
@@ -324,6 +326,13 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
if params.simulate_streaming:
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,