mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Minor fixes
This commit is contained in:
parent
7cc697c03a
commit
e923b1b336
@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import copy
|
||||
import math
|
||||
import warnings
|
||||
@ -234,7 +235,7 @@ class Conformer(EncoderInterface):
|
||||
|
||||
if not simulate_streaming:
|
||||
assert (
|
||||
decode_states is not None
|
||||
states is not None
|
||||
), "Require cache when sending data in streaming mode"
|
||||
|
||||
assert (
|
||||
@ -423,7 +424,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
# src: [chunk_size, N, F] e.g. [8, 41, 512]
|
||||
key = torch.cat([states[0, ...], src], dim=0)
|
||||
val = key
|
||||
states[0, ...] = key[-left_context, ...]
|
||||
states[0, ...] = key[-left_context:, ...]
|
||||
else:
|
||||
assert left_context == 0
|
||||
|
||||
@ -441,14 +442,15 @@ class ConformerEncoderLayer(nn.Module):
|
||||
src = src + self.dropout(src_att)
|
||||
|
||||
# convolution module
|
||||
residual = src
|
||||
if not self.training and states is not None:
|
||||
src = torch.cat([states[1, ...], src], dim=0)
|
||||
states[1, ...] = src[-left_context, ...]
|
||||
states[1, ...] = src[-left_context:, ...]
|
||||
|
||||
conv = self.conv_module(src)
|
||||
conv = conv[-src.size(0) :, :, :] # noqa: E203
|
||||
conv = conv[-residual.size(0) :, :, :] # noqa: E203
|
||||
|
||||
src = src + self.dropout(conv)
|
||||
src = residual + self.dropout(conv)
|
||||
|
||||
# feed forward module
|
||||
src = src + self.dropout(self.feed_forward(src))
|
||||
|
@ -70,6 +70,7 @@ Usage:
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
@ -101,6 +102,7 @@ from icefall.utils import (
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -324,6 +326,13 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
feature_lens += params.left_context
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.left_context),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
if params.simulate_streaming:
|
||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||
x=feature,
|
||||
|
Loading…
x
Reference in New Issue
Block a user