add tail padding

commit 10998bef69
parent 7307f1c6bc
Author: yaozengwei
Date:   2022-05-31 18:53:20 +08:00


@@ -59,6 +59,7 @@ Usage:
 import argparse
 import logging
+import math
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -70,7 +71,7 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import (
     beam_search,
-    fast_beam_search,
+    fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
@@ -91,6 +92,8 @@ from icefall.utils import (
     write_error_stats,
 )
 
+LOG_EPS = math.log(1e-10)
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
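`LOG_EPS` is the log of a very small energy. The features here are log-mel filterbanks, so padding with `LOG_EPS` appends frames of near-silence; padding with zeros would instead correspond to a linear energy of 1, which is not silent at all. For reference, a quick check of the constant:

import math

LOG_EPS = math.log(1e-10)
print(LOG_EPS)  # -23.025850929940457, effectively silence in the log-mel domain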
@@ -262,13 +265,20 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
+    feature_lens += params.right_context_length
+    feature = torch.nn.functional.pad(
+        feature,
+        pad=(0, 0, 0, params.right_context_length),
+        value=LOG_EPS,
+    )
+
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search(
+        hyp_tokens = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
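The tail padding above is the point of the commit. A streaming encoder that looks ahead `right_context_length` frames needs that many extra frames after the last real one; without them, the utterance-final frames cannot attend to their full right context. `torch.nn.functional.pad` reads the `pad` tuple from the last dimension backwards, so for a `(batch, time, feature)` tensor the tuple `(0, 0, 0, R)` leaves the feature axis untouched and appends `R` frames at the end of the time axis. A minimal self-contained sketch, with made-up shapes and a made-up `right_context_length`:

import math

import torch

LOG_EPS = math.log(1e-10)

# Hypothetical batch: 2 utterances, up to 100 frames, 80-dim log-mel features.
feature = torch.randn(2, 100, 80)
feature_lens = torch.tensor([100, 90])
right_context_length = 4  # stands in for params.right_context_length

# pad=(0, 0, 0, R) is consumed from the last dim backwards:
# (0, 0) for the feature dim, (0, R) for the time dim,
# so R frames of LOG_EPS are appended on the time axis only.
feature = torch.nn.functional.pad(
    feature,
    pad=(0, 0, 0, right_context_length),
    value=LOG_EPS,
)
feature_lens = feature_lens + right_context_length

assert feature.shape == (2, 104, 80)
assert feature_lens.tolist() == [104, 94]

Note that `feature_lens` is bumped as well, so the encoder treats the padded frames as real input and still emits output covering the final genuine frames.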
@@ -294,6 +304,7 @@ def decode_one_batch(
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
         for hyp in sp.decode(hyp_tokens):
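The last hunk threads `encoder_out_lens` into `modified_beam_search`. Batched encoder output is padded to the longest utterance in the batch, and the per-utterance lengths tell the search where each utterance actually ends, so it does not waste steps (or hallucinate tokens) on padded frames. A hypothetical sketch of how such lengths translate into a padding mask; the helper below is illustrative, not icefall's implementation:

import torch

def make_pad_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # True at padded positions: frame t of utterance i is padding iff t >= lengths[i].
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)

encoder_out_lens = torch.tensor([26, 23])  # made-up post-subsampling frame counts
mask = make_pad_mask(encoder_out_lens, max_len=int(encoder_out_lens.max()))

assert mask.shape == (2, 26)
assert not mask[0].any()   # longest utterance: no padding
assert mask[1, 23:].all()  # frames 23..25 of the shorter one are padding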