mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
remove tail padding for non-streaming models (#625)
This commit is contained in:
parent
03668771d7
commit
d389524d45
@ -380,14 +380,13 @@ def decode_one_batch(
|
|||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
|
if params.simulate_streaming:
|
||||||
feature_lens += params.left_context
|
feature_lens += params.left_context
|
||||||
feature = torch.nn.functional.pad(
|
feature = torch.nn.functional.pad(
|
||||||
feature,
|
feature,
|
||||||
pad=(0, 0, 0, params.left_context),
|
pad=(0, 0, 0, params.left_context),
|
||||||
value=LOG_EPS,
|
value=LOG_EPS,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.simulate_streaming:
|
|
||||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
|
@ -462,14 +462,13 @@ def decode_one_batch(
|
|||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
|
if params.simulate_streaming:
|
||||||
feature_lens += params.left_context
|
feature_lens += params.left_context
|
||||||
feature = torch.nn.functional.pad(
|
feature = torch.nn.functional.pad(
|
||||||
feature,
|
feature,
|
||||||
pad=(0, 0, 0, params.left_context),
|
pad=(0, 0, 0, params.left_context),
|
||||||
value=LOG_EPS,
|
value=LOG_EPS,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.simulate_streaming:
|
|
||||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
|
@ -411,14 +411,13 @@ def decode_one_batch(
|
|||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
|
if params.simulate_streaming:
|
||||||
feature_lens += params.left_context
|
feature_lens += params.left_context
|
||||||
feature = torch.nn.functional.pad(
|
feature = torch.nn.functional.pad(
|
||||||
feature,
|
feature,
|
||||||
pad=(0, 0, 0, params.left_context),
|
pad=(0, 0, 0, params.left_context),
|
||||||
value=LOG_EPS,
|
value=LOG_EPS,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.simulate_streaming:
|
|
||||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
|
@ -378,14 +378,13 @@ def decode_one_batch(
|
|||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
|
if params.simulate_streaming:
|
||||||
feature_lens += params.left_context
|
feature_lens += params.left_context
|
||||||
feature = torch.nn.functional.pad(
|
feature = torch.nn.functional.pad(
|
||||||
feature,
|
feature,
|
||||||
pad=(0, 0, 0, params.left_context),
|
pad=(0, 0, 0, params.left_context),
|
||||||
value=LOG_EPS,
|
value=LOG_EPS,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.simulate_streaming:
|
|
||||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user