mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
decoding and computing ppl with nnlm
This commit is contained in:
parent a80e58e15d
commit b9d6119932
23  egs/librispeech/ASR/conformer_ctc/compute_ppl.py  Normal file
@@ -0,0 +1,23 @@
#!/usr/bin/env python3

# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
# Apache 2.0

import numpy as np
import yaml

from utils.nnlm_evaluator import NNLMEvaluator

# An example of computing PPL with a transformer language model.
with open("conformer_ctc/lm_config.yaml") as f:
    lm_args = yaml.safe_load(f)
# TODO(Liyong Guo): make the model definition configurable.
lm_args.pop("model_config")

evaluator = NNLMEvaluator.build_evaluator(**lm_args, device="cuda")

res = evaluator.nll("conformer_ctc/data/transcripts/test_clean/text")
# Word-level PPL on test_clean is 89.71.
print(res.word_ppl)
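A minimal sketch of how PPLResult (defined in utils/nnlm_evaluator.py below) turns per-sentence NLLs into perplexity, i.e. exp(total_nll / count):

import numpy as np

def ppl(nlls, count):
    # Perplexity from summed negative log-likelihoods (natural log)
    # over `count` units: tokens for token PPL, words for word PPL.
    return float(np.exp(np.sum(nlls) / count))

# e.g. word PPL: ppl(res.nlls, res.nwords); token PPL: ppl(res.nlls, res.ntokens)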
5012  egs/librispeech/ASR/conformer_ctc/lm_config.yaml  Normal file
File diff suppressed because it is too large.
133  egs/librispeech/ASR/conformer_ctc/lm_transformer.py  Normal file
@@ -0,0 +1,133 @@
#!/usr/bin/env python3

# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
# Apache 2.0

from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformer import generate_square_subsequent_mask
from transformer import make_pad_mask
from transformer import TransformerEncoderLayer


class Encoder(nn.Module):
    def __init__(
        self,
        embed_unit=128,
        d_model=512,
        nhead=8,
        attention_dropout_rate=0.0,
        num_encoder_layers=16,
        dim_feedforward=2048,
        normalize_before=True,
    ):
        super().__init__()

        self.input_embed = nn.Sequential(
            nn.Linear(embed_unit, d_model),
            nn.LayerNorm(d_model),
            nn.Dropout(attention_dropout_rate),
            nn.ReLU(),
        )

        encoder_layer = TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            normalize_before=normalize_before,
            dropout=attention_dropout_rate,
        )

        self.encoders = nn.TransformerEncoder(
            encoder_layer, num_encoder_layers, nn.LayerNorm(d_model)
        )

    def forward(self, xs, token_lens):
        # xs: (N, S, E)
        xs = self.input_embed(xs)
        mask = generate_square_subsequent_mask(xs.shape[1]).to(xs.device)

        src_key_padding_mask = make_pad_mask(token_lens).to(xs.device)

        # xs: (N, S, E) -> (S, N, E)
        xs = xs.transpose(0, 1)
        xs = self.encoders(
            xs, mask=mask, src_key_padding_mask=src_key_padding_mask
        )
        # xs: (S, N, E) -> (N, S, E)
        xs = xs.transpose(0, 1)

        return xs


class TransformerLM(nn.Module):
    def __init__(
        self,
        num_encoder_layers: int = 16,
        vocab_size: int = 5000,
        embed_unit: int = 128,
        d_model: int = 512,
        nhead: int = 8,
        dim_feedforward: int = 2048,
        dropout_rate: float = 0.0,
        ignore_id: int = 0,
    ):
        super().__init__()

        # <sos> and <eos> share the last index of the vocabulary.
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.ignore_id = ignore_id

        self.embed = nn.Embedding(vocab_size, embed_unit)

        self.encoder = Encoder(
            embed_unit=embed_unit,
            d_model=d_model,
            nhead=nhead,
            attention_dropout_rate=dropout_rate,
            num_encoder_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
        )

        self.decoder = nn.Linear(d_model, vocab_size)

    def forward(
        self,
        src: torch.Tensor,
        token_lens: List[int],
    ) -> torch.Tensor:
        # src: (N, S)
        x = self.embed(src)
        h = self.encoder(x, token_lens)
        # y: (N, S, vocab_size)
        y = self.decoder(h)
        return y

    def nll(self, xs_pad, target_pad, token_lens):
        # xs_pad/target_pad: (N, S)
        # An example element of xs_pad:
        #   <sos> token token token ... token
        #
        # An example element of target_pad:
        #   token token token ... token <eos>

        y = self.forward(xs_pad, token_lens)

        # nll: (N * S,)
        nll = F.cross_entropy(
            y.view(-1, y.shape[-1]), target_pad.view(-1), reduction="none"
        )

        # Set padded positions to 0.0.
        nll.masked_fill_(make_pad_mask(token_lens).to(nll.device).view(-1), 0.0)

        # nll: (N * S,) -> (N, S)
        nll = nll.view(xs_pad.size(0), -1)
        return nll
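generate_square_subsequent_mask is imported from transformer.py and is not part of this diff; a minimal sketch of the causal mask it is assumed to produce (following the torch.nn.Transformer convention of -inf above the diagonal):

import torch

def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    # Additive attention mask: position i may attend to j <= i only;
    # entries above the diagonal are -inf, the rest are 0.
    return torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)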
@@ -898,6 +898,36 @@ def encoder_padding_mask(
         lengths[sequence_idx] = start_frame + num_frames

     lengths = [((i - 1) // 2 - 1) // 2 for i in lengths]
+    return make_pad_mask(lengths, max_len)
+
+
+def make_pad_mask(lengths: List[int], max_len: Optional[int] = None):
+    """Make mask tensor representing padded part.
+
+    Args:
+      lengths: (B,).
+      max_len: max_len in the batch.
+
+    Returns:
+      Tensor: Mask tensor representing padded part.
+
+    Examples:
+      With only lengths.
+
+      >>> lengths = [5, 3, 2]
+      >>> make_pad_mask(lengths)
+      masks = [[False, False, False, False, False],
+               [False, False, False, True, True],
+               [False, False, True, True, True]]
+
+      With lengths and max_len.
+
+      >>> lengths = [5, 3, 2]
+      >>> make_pad_mask(lengths, 6)
+      masks = [[False, False, False, False, False, True],
+               [False, False, False, True, True, True],
+               [False, False, True, True, True, True]]
+    """
+    if max_len is None:
+        max_len = int(max(lengths))
+
     bs = int(len(lengths))
     seq_range = torch.arange(0, max_len, dtype=torch.int64)
     seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
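The hunk is cut off by the diff viewer after seq_range_expand. A self-contained sketch of make_pad_mask consistent with the docstring above (the final comparison lines are an assumption about the omitted tail):

import torch
from typing import List, Optional

def make_pad_mask(lengths: List[int], max_len: Optional[int] = None):
    if max_len is None:
        max_len = int(max(lengths))
    bs = int(len(lengths))
    seq_range = torch.arange(0, max_len, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
    seq_length_expand = torch.tensor(lengths, dtype=torch.int64).unsqueeze(-1)
    # True at padded positions, i.e. where index >= length.
    return seq_range_expand >= seq_length_expand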
125  egs/librispeech/ASR/conformer_ctc/utils/nnlm_evaluator.py  Normal file
@@ -0,0 +1,125 @@
#!/usr/bin/env python3

# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
# Apache 2.0

from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import torch

from utils.text_dataset import (
    DatasetOption,
    TextFileDataIterator,
    TokenidsDataIterator,
    AbsLMDataIterator,
)
from utils.numericalizer import Numericalizer
from lm_transformer import TransformerLM

_TYPES_SUPPORTED = ["text_file", "word_id"]


def _validate_input_type(input_type: Optional[str] = None):
    # A valid input_type must be assigned by the caller.
    assert input_type is not None
    assert input_type in _TYPES_SUPPORTED


@dataclass(frozen=True)
class PPLResult:
    nlls: List[float]
    ntokens: int
    nwords: int

    @property
    def total_nll(self):
        return sum(self.nlls)

    @property
    def token_ppl(self):
        return np.exp(self.total_nll / self.ntokens)

    @property
    def word_ppl(self):
        return np.exp(self.total_nll / self.nwords)


@dataclass
class NNLMEvaluator:
    lm: TransformerLM
    dataset: AbsLMDataIterator
    device: Union[str, torch.device]

    @torch.no_grad()
    def nll(self, text_source):
        nlls = []
        total_ntokens = 0
        total_nwords = 0
        for xs_pad, target_pad, word_lens, token_lens in self.dataset(
            text_source
        ):
            xs_pad = xs_pad.to(self.device)
            target_pad = target_pad.to(self.device)

            nll = self.lm.nll(xs_pad, target_pad, token_lens)
            # Sum over the time axis to get one NLL per sentence.
            nll = nll.detach().cpu().numpy().sum(1)
            nlls.extend(nll)
            total_ntokens += sum(token_lens)
            total_nwords += sum(word_lens)
        ppl_result = PPLResult(
            nlls=nlls, ntokens=total_ntokens, nwords=total_nwords
        )
        return ppl_result

    @classmethod
    def build_evaluator(
        cls,
        lm: Optional[str] = None,
        bpemodel=None,
        token_list=None,
        device="cpu",
        input_type="text_file",
        batch_size=32,
        numericalizer=None,
        src_word_table=None,
    ):
        _validate_input_type(input_type)
        assert lm is not None

        model = TransformerLM()
        state_dict = torch.load(lm, map_location="cpu")
        model.load_state_dict(state_dict)
        model.to(device)

        if numericalizer is None:
            numericalizer = Numericalizer(
                tokenizer_file=bpemodel, token_list=token_list
            )

        dataset_option = DatasetOption(
            input_type=input_type,
            batch_size=batch_size,
            preprocessor=numericalizer,
        )

        if input_type == "text_file":
            dataset = TextFileDataIterator(dataset_option)
        elif input_type == "word_id":
            dataset = TokenidsDataIterator(
                dataset_option,
                numericalizer=numericalizer,
                src_word_table=src_word_table,
            )

        evaluator = cls(lm=model, dataset=dataset, device=device)
        return evaluator
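A usage sketch (the file paths are hypothetical; compute_ppl.py above shows the real entry point, which reads these arguments from lm_config.yaml):

token_list = [...]  # token list from lm_config.yaml (elided here)
evaluator = NNLMEvaluator.build_evaluator(
    lm="exp/lm/transformer_lm.pt",       # hypothetical checkpoint path
    bpemodel="data/lang_bpe/bpe.model",  # hypothetical BPE model path
    token_list=token_list,
    device="cuda",
    input_type="text_file",
)
res = evaluator.nll("data/transcripts/test_clean/text")
print(res.token_ppl, res.word_ppl)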
87  egs/librispeech/ASR/conformer_ctc/utils/numericalizer.py  Normal file
@@ -0,0 +1,87 @@
#!/usr/bin/env python3

# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
# Apache 2.0

from abc import ABC, abstractmethod
from typing import Iterable, List, Union

import numpy as np
import sentencepiece as spm


class PreProcessor(ABC):
    @abstractmethod
    def __call__(self, text: str) -> List[int]:
        raise NotImplementedError


class Numericalizer(PreProcessor):
    def __init__(self, tokenizer_file, token_list, unk_symbol="<unk>"):
        self.tokenizer_file = tokenizer_file
        self.token_list = token_list
        self.unk_symbol = unk_symbol
        self._token2idx = None
        self._tokenizer = None
        self._assign_special_symbols()

    def _assign_special_symbols(self):
        # <sos> and <eos> share the same index for models downloaded
        # from the ESPnet model zoo.
        assert "<sos/eos>" in self.token2idx or (
            "<sos>" in self.token2idx and "<eos>" in self.token2idx
        )
        assert self.unk_symbol in self.token2idx
        self.sos_idx = (
            self.token2idx["<sos/eos>"]
            if "<sos/eos>" in self.token2idx
            else self.token2idx["<sos>"]
        )
        self.eos_idx = (
            self.token2idx["<sos/eos>"]
            if "<sos/eos>" in self.token2idx
            else self.token2idx["<eos>"]
        )
        self.unk_idx = self.token2idx[self.unk_symbol]

    @property
    def tokenizer(self):
        if self._tokenizer is None:
            sp = spm.SentencePieceProcessor()
            sp.Load(self.tokenizer_file)
            self._tokenizer = sp
        return self._tokenizer

    def text2tokens(self, line: str) -> List[str]:
        return self.tokenizer.EncodeAsPieces(line)

    def tokens2text(self, tokens: Iterable[str]) -> str:
        return self.tokenizer.DecodePieces(list(tokens))

    @property
    def token2idx(self):
        if self._token2idx is None:
            self._token2idx = {}
            for idx, token in enumerate(self.token_list):
                if token in self._token2idx:
                    raise RuntimeError(f'Symbol "{token}" is duplicated')
                self._token2idx[token] = idx

        return self._token2idx

    def ids2tokens(
        self, integers: Union[np.ndarray, Iterable[int]]
    ) -> List[str]:
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]

    def __call__(self, text: str) -> List[int]:
        tokens = self.text2tokens(text)
        token_idxs = (
            [self.sos_idx]
            + [self.token2idx.get(token, self.unk_idx) for token in tokens]
            + [self.eos_idx]
        )
        return token_idxs
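A small usage sketch (token list and model file are hypothetical); __call__ wraps the BPE ids with <sos>/<eos>:

token_list = ["<unk>", "▁HELLO", "▁WORLD", "<sos/eos>"]  # hypothetical
numericalizer = Numericalizer(
    tokenizer_file="bpe.model",  # hypothetical sentencepiece model
    token_list=token_list,
)
ids = numericalizer("HELLO WORLD")
# Assuming the model tokenizes the text as ▁HELLO ▁WORLD:
# ids == [3, 1, 2, 3], i.e. [<sos>, ▁HELLO, ▁WORLD, <eos>]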
125  egs/librispeech/ASR/conformer_ctc/utils/text_dataset.py  Normal file
@@ -0,0 +1,125 @@
#!/usr/bin/env python3

# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
# Apache 2.0

from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

from utils.numericalizer import PreProcessor


class CollateFunc(object):
    """Collate function for LMDataset."""

    def __init__(self, pad_index=None):
        # pad_index should be identical to ignore_index of torch.nn.NLLLoss
        # and padding_idx in torch.nn.Embedding.
        self.pad_index = pad_index

    def __call__(self, batch: List[List[int]]):
        """
        batch is a ragged 2-d array; each row represents a tokenized text
        in the format:
            <bos_id> token_id token_id token_id ... <eos_id>
        """
        # data_pad: [batch_size, max_seq_len]
        # max_seq_len == len(max(batch, key=len))
        data_pad = pad_sequence(
            [torch.from_numpy(np.array(x)).long() for x in batch],
            True,
            self.pad_index,
        )
        data_pad = data_pad.contiguous()
        # Shift by one: inputs drop the final position, targets drop the
        # leading <bos>, so xs_pad/ys_pad both have shape
        # [batch_size, max_seq_len - 1].
        xs_pad = data_pad[:, :-1].contiguous()
        ys_pad = data_pad[:, 1:].contiguous()
        return xs_pad, ys_pad
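A worked example of the shift-by-one collation (values illustrative; 9 plays the role of the shared <bos>/<eos> id and 0 is the pad):

collate = CollateFunc(pad_index=0)
xs, ys = collate([[9, 4, 5, 9], [9, 7, 9]])
# data_pad == [[9, 4, 5, 9], [9, 7, 9, 0]]
# xs       == [[9, 4, 5], [9, 7, 9]]   inputs keep <bos>
# ys       == [[4, 5, 9], [7, 9, 0]]   targets keep <eos>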

@dataclass
class DatasetOption:
    preprocessor: PreProcessor
    input_type: Optional[str] = "text_file"
    batch_size: int = 32
    pad_value: int = 0


@dataclass
class AbsLMDataIterator(ABC):
    preprocessor: PreProcessor
    input_type: Optional[str] = "text_file"
    batch_size: int = 32
    pad_value: int = 0
    words_txt: Optional[Path] = None
    _collate_fn = None

    @property
    def collate_fn(self):
        if self._collate_fn is None:
            self._collate_fn = CollateFunc(self.pad_value)
        return self._collate_fn

    def _reset_container(self):
        self.token_ids_list = []
        self.token_lens = []
        self.word_lens = []

    @abstractmethod
    def _text_generator(self, text_source):
        raise NotImplementedError

    def __call__(self, text_source):
        """
        Args:
          text_source: either a text file or word sequences.
        """
        self._reset_container()
        for text in self._text_generator(text_source):
            self.word_lens.append(len(text.split()) + 1)  # +1 for <eos>

            token_ids = self.preprocessor(text)
            self.token_ids_list.append(token_ids)
            self.token_lens.append(len(token_ids) - 1)  # -1 to remove <sos>

            if len(self.token_ids_list) == self.batch_size:
                xs_pad, ys_pad = self.collate_fn(self.token_ids_list)
                yield xs_pad, ys_pad, self.word_lens, self.token_lens
                self._reset_container()

        # Flush the last, possibly smaller, batch.
        if len(self.token_ids_list) != 0:
            xs_pad, ys_pad = self.collate_fn(self.token_ids_list)
            yield xs_pad, ys_pad, self.word_lens, self.token_lens
            self._reset_container()


class TextFileDataIterator(AbsLMDataIterator):
    def __init__(self, dataset_option):
        super().__init__(**(dataset_option.__dict__))

    def _text_generator(self, text_file):
        with open(text_file, "r") as f:
            for text in f:
                # Drop the leading utterance id.
                text = text.strip().split(maxsplit=1)[1]
                yield text


class TokenidsDataIterator(AbsLMDataIterator):
    def __init__(self, dataset_option, numericalizer, src_word_table):
        super().__init__(**(dataset_option.__dict__))
        self.numericalizer = numericalizer
        self.src_word_table = src_word_table

    def _text_generator(self, token_ids):
        for utt in token_ids:
            text = " ".join([self.src_word_table[token] for token in utt])
            text = text.upper()
            yield text
@@ -773,6 +773,7 @@ def rescore_with_attention_decoder(
     ngram_lm_scale: Optional[float] = None,
     attention_scale: Optional[float] = None,
     use_double_scores: bool = True,
+    nnlm_evaluator=None,
 ) -> Dict[str, k2.Fsa]:
     """This function extracts `num_paths` paths from the given lattice and uses
     an attention decoder to rescore them. The path with the highest score is
@@ -854,11 +855,21 @@ def rescore_with_attention_decoder(
         sos_id=sos_id,
         eos_id=eos_id,
     )

     assert nll.ndim == 2
     assert nll.shape[0] == len(token_ids)

     attention_scores = -nll.sum(dim=1)

+    if nnlm_evaluator is not None:
+        aux_labels = k2.RaggedTensor(tokens_shape, nbest.fsa.aux_labels)
+        aux_labels = aux_labels.remove_values_leq(0)
+        aux_labels = aux_labels.tolist()
+        assert len(aux_labels) == len(token_ids)
+
+        ppl_result = nnlm_evaluator.nll(aux_labels)
+        nnlm_scores = -torch.tensor(ppl_result.nlls).to(attention_scores.device)
+        assert nnlm_scores.shape[0] == len(token_ids)
+
     if ngram_lm_scale is None:
         ngram_lm_scale_list = [0.01, 0.05, 0.08]
         ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
@@ -881,6 +892,8 @@ def rescore_with_attention_decoder(
             + n_scale * ngram_lm_scores.values
             + a_scale * attention_scores
         )
+        if nnlm_evaluator is not None:
+            tot_scores = tot_scores + nnlm_scores
         ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
         max_indexes = ragged_tot_scores.argmax()
         best_path = k2.index_fsa(nbest.fsa, max_indexes)
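Taken together, when nnlm_evaluator is passed, the per-path total score becomes, schematically (the am_scores term comes from surrounding code that this diff does not show, so its exact name is an assumption):

# Schematic only; names follow the hunks above.
tot_scores = (
    am_scores.values
    + n_scale * ngram_lm_scores.values
    + a_scale * attention_scores
    + nnlm_scores  # present only when nnlm_evaluator is not None
)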