remove tail padding for non-streaming models (#625)

This commit is contained in:
Wei Kang 2022-11-01 11:09:56 +08:00 committed by GitHub
parent 03668771d7
commit d389524d45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 24 additions and 28 deletions

View File

@@ -380,14 +380,13 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    feature_lens += params.left_context
-    feature = torch.nn.functional.pad(
-        feature,
-        pad=(0, 0, 0, params.left_context),
-        value=LOG_EPS,
-    )
-
     if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,

View File

@@ -462,14 +462,13 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    feature_lens += params.left_context
-    feature = torch.nn.functional.pad(
-        feature,
-        pad=(0, 0, 0, params.left_context),
-        value=LOG_EPS,
-    )
-
     if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,

View File

@@ -411,14 +411,13 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    feature_lens += params.left_context
-    feature = torch.nn.functional.pad(
-        feature,
-        pad=(0, 0, 0, params.left_context),
-        value=LOG_EPS,
-    )
-
     if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,

View File

@@ -378,14 +378,13 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    feature_lens += params.left_context
-    feature = torch.nn.functional.pad(
-        feature,
-        pad=(0, 0, 0, params.left_context),
-        value=LOG_EPS,
-    )
-
     if params.simulate_streaming:
+        feature_lens += params.left_context
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, params.left_context),
+            value=LOG_EPS,
+        )
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,