mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 00:24:19 +00:00
Minor fixes
This commit is contained in:
parent
7cc697c03a
commit
e923b1b336
@ -15,6 +15,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
import copy
|
import copy
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
@ -234,7 +235,7 @@ class Conformer(EncoderInterface):
|
|||||||
|
|
||||||
if not simulate_streaming:
|
if not simulate_streaming:
|
||||||
assert (
|
assert (
|
||||||
decode_states is not None
|
states is not None
|
||||||
), "Require cache when sending data in streaming mode"
|
), "Require cache when sending data in streaming mode"
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
@ -423,7 +424,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
# src: [chunk_size, N, F] e.g. [8, 41, 512]
|
# src: [chunk_size, N, F] e.g. [8, 41, 512]
|
||||||
key = torch.cat([states[0, ...], src], dim=0)
|
key = torch.cat([states[0, ...], src], dim=0)
|
||||||
val = key
|
val = key
|
||||||
states[0, ...] = key[-left_context, ...]
|
states[0, ...] = key[-left_context:, ...]
|
||||||
else:
|
else:
|
||||||
assert left_context == 0
|
assert left_context == 0
|
||||||
|
|
||||||
@ -441,14 +442,15 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
src = src + self.dropout(src_att)
|
src = src + self.dropout(src_att)
|
||||||
|
|
||||||
# convolution module
|
# convolution module
|
||||||
|
residual = src
|
||||||
if not self.training and states is not None:
|
if not self.training and states is not None:
|
||||||
src = torch.cat([states[1, ...], src], dim=0)
|
src = torch.cat([states[1, ...], src], dim=0)
|
||||||
states[1, ...] = src[-left_context, ...]
|
states[1, ...] = src[-left_context:, ...]
|
||||||
|
|
||||||
conv = self.conv_module(src)
|
conv = self.conv_module(src)
|
||||||
conv = conv[-src.size(0) :, :, :] # noqa: E203
|
conv = conv[-residual.size(0) :, :, :] # noqa: E203
|
||||||
|
|
||||||
src = src + self.dropout(conv)
|
src = residual + self.dropout(conv)
|
||||||
|
|
||||||
# feed forward module
|
# feed forward module
|
||||||
src = src + self.dropout(self.feed_forward(src))
|
src = src + self.dropout(self.feed_forward(src))
|
||||||
|
@ -70,6 +70,7 @@ Usage:
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
@ -101,6 +102,7 @@ from icefall.utils import (
|
|||||||
write_error_stats,
|
write_error_stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -324,6 +326,13 @@ def decode_one_batch(
|
|||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
|
feature_lens += params.left_context
|
||||||
|
feature = torch.nn.functional.pad(
|
||||||
|
feature,
|
||||||
|
pad=(0, 0, 0, params.left_context),
|
||||||
|
value=LOG_EPS,
|
||||||
|
)
|
||||||
|
|
||||||
if params.simulate_streaming:
|
if params.simulate_streaming:
|
||||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||||
x=feature,
|
x=feature,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user