Add tail padding (right_context_length frames of LOG_EPS) to features in decode_one_batch

This commit is contained in:
yaozengwei 2022-05-31 18:53:20 +08:00
parent 7307f1c6bc
commit 10998bef69

View File

@ -59,6 +59,7 @@ Usage:
import argparse import argparse
import logging import logging
import math
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
@ -70,7 +71,7 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search, fast_beam_search_one_best,
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
@ -91,6 +92,8 @@ from icefall.utils import (
write_error_stats, write_error_stats,
) )
LOG_EPS = math.log(1e-10)
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -262,13 +265,20 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
feature_lens += params.right_context_length
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.right_context_length),
value=LOG_EPS,
)
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens x=feature, x_lens=feature_lens
) )
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search( hyp_tokens = fast_beam_search_one_best(
model=model, model=model,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -294,6 +304,7 @@ def decode_one_batch(
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):