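Minor fixes: keep the last left_context frames (slice, not a single index) in the Conformer streaming caches, check `states` in the streaming assert, add the conv-module output to the un-concatenated residual, and pad features by left_context frames of LOG_EPS when decoding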

This commit is contained in:
pkufool 2022-05-26 10:00:18 +08:00
parent 7cc697c03a
commit e923b1b336
2 changed files with 16 additions and 5 deletions

View File

@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import copy import copy
import math import math
import warnings import warnings
@@ -234,7 +235,7 @@ class Conformer(EncoderInterface):
if not simulate_streaming: if not simulate_streaming:
assert ( assert (
decode_states is not None states is not None
), "Require cache when sending data in streaming mode" ), "Require cache when sending data in streaming mode"
assert ( assert (
@@ -423,7 +424,7 @@ class ConformerEncoderLayer(nn.Module):
# src: [chunk_size, N, F] e.g. [8, 41, 512] # src: [chunk_size, N, F] e.g. [8, 41, 512]
key = torch.cat([states[0, ...], src], dim=0) key = torch.cat([states[0, ...], src], dim=0)
val = key val = key
states[0, ...] = key[-left_context, ...] states[0, ...] = key[-left_context:, ...]
else: else:
assert left_context == 0 assert left_context == 0
@@ -441,14 +442,15 @@ class ConformerEncoderLayer(nn.Module):
src = src + self.dropout(src_att) src = src + self.dropout(src_att)
# convolution module # convolution module
residual = src
if not self.training and states is not None: if not self.training and states is not None:
src = torch.cat([states[1, ...], src], dim=0) src = torch.cat([states[1, ...], src], dim=0)
states[1, ...] = src[-left_context, ...] states[1, ...] = src[-left_context:, ...]
conv = self.conv_module(src) conv = self.conv_module(src)
conv = conv[-src.size(0) :, :, :] # noqa: E203 conv = conv[-residual.size(0) :, :, :] # noqa: E203
src = src + self.dropout(conv) src = residual + self.dropout(conv)
# feed forward module # feed forward module
src = src + self.dropout(self.feed_forward(src)) src = src + self.dropout(self.feed_forward(src))

View File

@@ -70,6 +70,7 @@ Usage:
import argparse import argparse
import logging import logging
import math
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
@@ -101,6 +102,7 @@ from icefall.utils import (
write_error_stats, write_error_stats,
) )
LOG_EPS = math.log(1e-10)
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@@ -324,6 +326,13 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
if params.simulate_streaming: if params.simulate_streaming:
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature, x=feature,