Decoding and computing PPL with NNLM

Guo Liyong 2021-10-08 19:19:15 +08:00
parent a80e58e15d
commit b9d6119932
8 changed files with 5548 additions and 0 deletions

View File

@@ -0,0 +1,23 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
# Apache 2.0
import numpy as np
import yaml
from utils.nnlm_evaluator import NNLMEvaluator
# An example of computing PPL with a transformer language model
with open("conformer_ctc/lm_config.yaml") as f:
lm_args = yaml.safe_load(f)
# TODO(Liyong Guo): make model definition configurable
lm_args.pop("model_config")
evaluator = NNLMEvaluator.build_evaluator(**lm_args, device="cuda")
res = evaluator.nll(
"conformer_ctc/data/transcripts/test_clean/text"
)
# prints 89.71 on test_clean (the mean per-utterance negative
# log-likelihood; see PPLResult for token/word perplexities)
print(np.mean(res.nlls))
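
Since lm_config.yaml itself is not part of this diff, the following is only a hypothetical sketch of the keyword arguments that build_evaluator (utils/nnlm_evaluator.py, below) ends up receiving; every path and value here is an invented placeholder:

lm_args = {
    "model_config": {},  # popped above; the model definition is not yet configurable
    "lm": "exp/lm/model.pt",                # checkpoint path passed to torch.load
    "bpemodel": "data/lang_bpe/bpe.model",  # sentencepiece model file
    "token_list": ["<blk>", "<unk>", "▁THE", "<sos/eos>"],  # full BPE inventory
}

After lm_args.pop("model_config"), the remaining keys match build_evaluator's signature.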

File diff suppressed because it is too large

View File

@@ -0,0 +1,133 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
# Apache 2.0
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformer import (
    TransformerEncoderLayer,
    generate_square_subsequent_mask,
    make_pad_mask,
)

class Encoder(nn.Module):
def __init__(
self,
embed_unit=128,
d_model=512,
nhead=8,
attention_dropout_rate=0.0,
num_encoder_layers=16,
dim_feedforward=2048,
normalize_before=True,
):
super().__init__()
self.input_embed = nn.Sequential(
nn.Linear(embed_unit, d_model),
nn.LayerNorm(d_model),
nn.Dropout(attention_dropout_rate),
nn.ReLU(),
)
encoder_layer = TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
            normalize_before=normalize_before,
dropout=attention_dropout_rate,
)
self.encoders = nn.TransformerEncoder(
encoder_layer, num_encoder_layers, nn.LayerNorm(d_model)
)
def forward(self, xs, token_lens):
# xs: N S E
xs = self.input_embed(xs)
mask = generate_square_subsequent_mask(xs.shape[1]).to(xs.device)
src_key_padding_mask = make_pad_mask(token_lens).to(xs.device)
# xs: N S E --> S N E
xs = xs.transpose(0, 1)
xs = self.encoders(
xs, mask=mask, src_key_padding_mask=src_key_padding_mask
)
# xs: S N E --> N S E
xs = xs.transpose(0, 1)
return xs
class TransformerLM(nn.Module):
def __init__(
self,
num_encoder_layers: int = 16,
vocab_size: int = 5000,
embed_unit: int = 128,
d_model: int = 512,
nhead: int = 8,
dim_feedforward: int = 2048,
dropout_rate: float = 0.0,
ignore_id: int = 0,
):
super().__init__()
self.sos = vocab_size - 1
self.eos = vocab_size - 1
self.ignore_id = ignore_id
self.embed = nn.Embedding(vocab_size, embed_unit)
        self.encoder = Encoder(
            embed_unit=embed_unit,
            d_model=d_model,
            nhead=nhead,
            attention_dropout_rate=dropout_rate,
            num_encoder_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
        )
self.decoder = nn.Linear(d_model, vocab_size)
def forward(
self,
src: torch.Tensor,
token_lens,
    ) -> torch.Tensor:
# src: N, S
x = self.embed(src)
h = self.encoder(x, token_lens)
        # y: N, S, vocab_size
y = self.decoder(h)
return y
def nll(self, xs_pad, target_pad, token_lens):
# xs_pad/target_pad: N, S
# An example element of xs_pad:
# <sos> token token token ... token
#
# An example element of target_pad:
# token token token ... token <eos>
y = self.forward(xs_pad, token_lens)
# nll: (N * S,)
nll = F.cross_entropy(
y.view(-1, y.shape[-1]), target_pad.view(-1), reduction="none"
)
        # set padded positions to 0.0
nll.masked_fill_(make_pad_mask(token_lens).to(nll.device).view(-1), 0.0)
# nll: (N * S,) -> (N, S)
nll = nll.view(xs_pad.size(0), -1)
return nll
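
The <sos>/<eos> shifting convention that nll expects can be exercised directly; below is a minimal sketch with an invented 10-token vocabulary (all sizes and ids are illustrative, not from the commit), assuming lm_transformer.py is importable:

import torch
from lm_transformer import TransformerLM

# Tiny model purely for illustration; <sos>/<eos> share index vocab_size - 1.
lm = TransformerLM(vocab_size=10, num_encoder_layers=2, d_model=16,
                   nhead=2, dim_feedforward=32)
sos = eos = 10 - 1
# Two utterances with token ids [3, 4, 5] and [6, 7], padded with 0:
xs_pad = torch.tensor([[sos, 3, 4, 5], [sos, 6, 7, 0]])
target_pad = torch.tensor([[3, 4, 5, eos], [6, 7, eos, 0]])
token_lens = [4, 3]  # per-utterance lengths excluding padding
nll = lm.nll(xs_pad, target_pad, token_lens)
print(nll.shape)  # torch.Size([2, 4]); padded positions hold 0.0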

View File

@@ -898,6 +898,36 @@ def encoder_padding_mask(
        lengths[sequence_idx] = start_frame + num_frames
    lengths = [((i - 1) // 2 - 1) // 2 for i in lengths]
return make_pad_mask(lengths, max_len)
def make_pad_mask(lengths: List[int], max_len: Optional[int] = None):
"""Make mask tensor representing padded part.
Args:
lengths: (B,).
max_len: max_len in the batch
Returns:
Tensor: Mask tensor representing padded part.
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths)
masks = [[False, False, False, False, False],
[False, False, False, True, True],
[False, False, True, True, True]]
With lengths and max_len.
>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths, 6)
masks = [[False, False, False, False, False, True],
[False, False, False, True, True, True],
[False, False, True, True, True, True]]
"""
if max_len is None:
max_len = int(max(lengths))
    bs = int(len(lengths))
    seq_range = torch.arange(0, max_len, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)

View File

@@ -0,0 +1,125 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
# Apache 2.0
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import torch
from utils.text_dataset import (
DatasetOption,
TextFileDataIterator,
TokenidsDataIterator,
AbsLMDataIterator,
)
from utils.numericalizer import Numericalizer
from lm_transformer import TransformerLM
_TYPES_SUPPORTED = ["text_file", "word_id"]
def _validate_input_type(input_type: Optional[str] = None):
    # A valid input_type must be supplied by the caller
assert input_type is not None
assert input_type in _TYPES_SUPPORTED
@dataclass(frozen=True)
class PPLResult:
nlls: List[float]
ntokens: int
nwords: int
@property
def total_nll(self):
return sum(self.nlls)
@property
def token_ppl(self):
return np.exp(self.total_nll / self.ntokens)
@property
def word_ppl(self):
return np.exp(self.total_nll / self.nwords)
@dataclass
class NNLMEvaluator:
    lm: TransformerLM
    dataset: AbsLMDataIterator
    device: Union[str, torch.device]

@torch.no_grad()
def nll(self, text_source):
        nlls = []
        total_ntokens = 0
        total_nwords = 0
for xs_pad, target_pad, word_lens, token_lens in self.dataset(
text_source
):
xs_pad = xs_pad.to(self.device)
target_pad = target_pad.to(self.device)
nll = self.lm.nll(xs_pad, target_pad, token_lens)
nll = nll.detach().cpu().numpy().sum(1)
nlls.extend(nll)
total_ntokens += sum(token_lens)
total_nwords += sum(word_lens)
ppl_result = PPLResult(
nlls=nlls, ntokens=total_ntokens, nwords=total_nwords
)
return ppl_result
@classmethod
def build_evaluator(
cls,
        lm: Optional[str] = None,
bpemodel=None,
token_list=None,
device="cpu",
input_type="text_file",
batch_size=32,
numericalizer=None,
src_word_table=None,
):
_validate_input_type(input_type)
assert lm is not None
model = TransformerLM()
        state_dict = torch.load(lm, map_location="cpu")
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
if numericalizer is None:
numericalizer = Numericalizer(
tokenizer_file=bpemodel, token_list=token_list
)
dataset_option = DatasetOption(
input_type=input_type,
batch_size=batch_size,
preprocessor=numericalizer,
)
if input_type == "text_file":
dataset = TextFileDataIterator(dataset_option)
elif input_type == "word_id":
dataset = TokenidsDataIterator(
dataset_option,
numericalizer=numericalizer,
src_word_table=src_word_table,
)
evaluator = NNLMEvaluator(lm=model, dataset=dataset, device=device)
return evaluator
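
To make the PPLResult properties concrete, here is a small worked example; the NLL values, token count, and word count are invented for illustration:

from utils.nnlm_evaluator import PPLResult

# Two utterances whose summed token NLLs are 11.2 and 8.8,
# spanning 40 BPE tokens and 25 words in total.
res = PPLResult(nlls=[11.2, 8.8], ntokens=40, nwords=25)
print(res.total_nll)  # 20.0
print(res.token_ppl)  # exp(20.0 / 40) ≈ 1.65
print(res.word_ppl)   # exp(20.0 / 25) ≈ 2.23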

View File

@@ -0,0 +1,87 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
# Apache 2.0
from abc import ABC, abstractmethod
from typing import Iterable, List, Union

import numpy as np
import sentencepiece as spm
class PreProcessor(ABC):
@abstractmethod
def __call__(self, text: str) -> List[int]:
raise NotImplementedError
class Numericalizer(PreProcessor):
    def __init__(self, tokenizer_file, token_list, unk_symbol="<unk>"):
        self.tokenizer_file = tokenizer_file
        self.token_list = token_list
        self.unk_symbol = unk_symbol
        self._token2idx = None
        self._tokenizer = None
        self._assign_special_symbols()
def _assign_special_symbols(self):
        # <sos> and <eos> share the same index for models downloaded
        # from the espnet model zoo.
        assert "<sos/eos>" in self.token2idx or (
            "<sos>" in self.token2idx and "<eos>" in self.token2idx
        )
assert "<unk>" in self.token2idx
self.sos_idx = (
self.token2idx["<sos/eos>"]
if "<sos/eos>" in self.token2idx
else self.token2idx["<sos>"]
)
self.eos_idx = (
self.token2idx["<sos/eos>"]
if "<sos/eos>" in self.token2idx
else self.token2idx["<eos>"]
)
        self.unk_idx = self.token2idx[self.unk_symbol]
@property
def tokenizer(self):
if self._tokenizer is None:
sp = spm.SentencePieceProcessor()
sp.Load(self.tokenizer_file)
self._tokenizer = sp
return self._tokenizer
def text2tokens(self, line: str) -> List[str]:
return self.tokenizer.EncodeAsPieces(line)
def tokens2text(self, tokens: Iterable[str]) -> str:
return self.tokenizer.DecodePieces(list(tokens))
@property
def token2idx(self):
if self._token2idx is None:
self._token2idx = {}
for idx, token in enumerate(self.token_list):
if token in self._token2idx:
raise RuntimeError(f'Symbol "{token}" is duplicated')
self._token2idx[token] = idx
return self._token2idx
def ids2tokens(
self, integers: Union[np.ndarray, Iterable[int]]
) -> List[str]:
if isinstance(integers, np.ndarray) and integers.ndim != 1:
raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
return [self.token_list[i] for i in integers]
def __call__(self, text: str) -> List[int]:
tokens = self.text2tokens(text)
token_idxs = (
[self.sos_idx]
+ [self.token2idx.get(token, self.unk_idx) for token in tokens]
+ [self.eos_idx]
)
return token_idxs
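
A usage sketch for Numericalizer; the BPE model path, the token inventory, and the resulting pieces are all invented for illustration:

from utils.numericalizer import Numericalizer

token_list = ["<blk>", "<unk>", "▁HELLO", "▁WORLD", "<sos/eos>"]
numericalizer = Numericalizer(
    tokenizer_file="data/lang_bpe/bpe.model",  # placeholder path
    token_list=token_list,
)
# Assuming the BPE model splits the line into ▁HELLO ▁WORLD:
ids = numericalizer("HELLO WORLD")
# -> [4, 2, 3, 4]: <sos/eos>, the two piece ids, <sos/eos>;
#    out-of-inventory pieces would map to <unk> (index 1).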

View File

@@ -0,0 +1,125 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
# Apache 2.0
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

from utils.numericalizer import PreProcessor
class CollateFunc(object):
"""Collate function for LMDataset"""
def __init__(self, pad_index=None):
# pad_index should be identical to ignore_index of torch.nn.NLLLoss
# and padding_idx in torch.nn.Embedding
self.pad_index = pad_index
def __call__(self, batch: List[List[int]]):
"""
batch is a ragged 2-d array, with a row
represents a tokenized text, whose format is:
<bos_id> token_id token_id token_id *** <eos_id>
"""
# data_pad: [batch_size, max_seq_len]
# max_seq_len == len(max(batch, key=len))
        data_pad = pad_sequence(
            [torch.from_numpy(np.array(x)).long() for x in batch],
            batch_first=True,
            padding_value=self.pad_index,
        )
data_pad = data_pad.contiguous()
xs_pad = data_pad[:, :-1].contiguous()
ys_pad = data_pad[:, 1:].contiguous()
        # xs_pad/ys_pad: [batch_size, max_seq_len - 1]
        # xs_pad drops the final column (<eos>); ys_pad drops the first (<bos>)
return xs_pad, ys_pad
@dataclass
class DatasetOption:
preprocessor: PreProcessor
input_type: Optional[str] = "text_file"
batch_size: int = 32
pad_value: int = 0
@dataclass
class AbsLMDataIterator(ABC):
preprocessor: PreProcessor
input_type: Optional[str] = "text_file"
batch_size: int = 32
pad_value: int = 0
words_txt: Optional[Path] = None
_collate_fn = None
@property
def collate_fn(self):
if self._collate_fn is None:
self._collate_fn = CollateFunc(self.pad_value)
return self._collate_fn
def _reset_container(self):
self.token_ids_list = []
self.token_lens = []
self.word_lens = []
@abstractmethod
def _text_generator(self, text_source):
raise NotImplementedError
def __call__(self, text_source):
"""
Args:
text_source may be text_file / word_seqs
"""
self._reset_container()
for text in self._text_generator(text_source):
self.word_lens.append(len(text.split()) + 1) # +1 for <eos>
token_ids = self.preprocessor(text)
self.token_ids_list.append(token_ids)
self.token_lens.append(len(token_ids) - 1) # -1 to remove <sos>
if len(self.token_ids_list) == self.batch_size:
xs_pad, ys_pad = self.collate_fn(self.token_ids_list)
yield xs_pad, ys_pad, self.word_lens, self.token_lens
self._reset_container()
if len(self.token_ids_list) != 0:
xs_pad, ys_pad = self.collate_fn(self.token_ids_list)
yield xs_pad, ys_pad, self.word_lens, self.token_lens
self._reset_container()
class TextFileDataIterator(AbsLMDataIterator):
def __init__(self, dataset_option):
super().__init__(**(dataset_option.__dict__))
def _text_generator(self, text_file):
with open(text_file, "r") as f:
for text in f:
                # Each line is "<utt-id> <transcript>"; keep the transcript.
                text = text.strip().split(maxsplit=1)[1]
yield text
class TokenidsDataIterator(AbsLMDataIterator):
def __init__(self, dataset_option, numericalizer, src_word_table):
super().__init__(**(dataset_option.__dict__))
self.numericalizer = numericalizer
self.src_word_table = src_word_table
def _text_generator(self, token_ids):
for utt in token_ids:
text = " ".join([self.src_word_table[token] for token in utt])
text = text.upper()
yield text
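
The xs/ys shift done by CollateFunc can be seen on a toy batch; the ids below are invented, with 9 playing the shared <bos>/<eos> role:

from utils.text_dataset import CollateFunc

collate = CollateFunc(pad_index=0)
batch = [[9, 3, 4, 5, 9], [9, 6, 7, 9]]
xs, ys = collate(batch)
# xs: tensor([[9, 3, 4, 5], [9, 6, 7, 9]])  -- last column dropped
# ys: tensor([[3, 4, 5, 9], [6, 7, 9, 0]])  -- first column dropped
# Positions beyond each utterance's token_lens are masked out downstream.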

View File

@@ -773,6 +773,7 @@ def rescore_with_attention_decoder(
    ngram_lm_scale: Optional[float] = None,
    attention_scale: Optional[float] = None,
    use_double_scores: bool = True,
nnlm_evaluator=None,
) -> Dict[str, k2.Fsa]:
    """This function extracts `num_paths` paths from the given lattice and uses
    an attention decoder to rescore them. The path with the highest score is
@@ -854,11 +855,21 @@ def rescore_with_attention_decoder(
        sos_id=sos_id,
        eos_id=eos_id,
    )
    assert nll.ndim == 2
    assert nll.shape[0] == len(token_ids)
    attention_scores = -nll.sum(dim=1)
if nnlm_evaluator is not None:
aux_labels = k2.RaggedTensor(tokens_shape, nbest.fsa.aux_labels)
aux_labels = aux_labels.remove_values_leq(0)
aux_labels = aux_labels.tolist()
assert len(aux_labels) == len(token_ids)
ppl_result = nnlm_evaluator.nll(aux_labels)
nnlm_scores = -torch.tensor(ppl_result.nlls).to(attention_scores.device)
assert nnlm_scores.shape[0] == len(token_ids)
    if ngram_lm_scale is None:
        ngram_lm_scale_list = [0.01, 0.05, 0.08]
        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
@@ -881,6 +892,8 @@ def rescore_with_attention_decoder(
                + n_scale * ngram_lm_scores.values
                + a_scale * attention_scores
            )
if nnlm_evaluator is not None:
tot_scores = tot_scores + nnlm_scores
        ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
        max_indexes = ragged_tot_scores.argmax()
        best_path = k2.index_fsa(nbest.fsa, max_indexes)
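
Since this rescoring path hands word-id sequences (aux_labels) to the evaluator, the evaluator has to be built with input_type="word_id". A hedged sketch using build_evaluator from this commit; the paths are placeholders and word_table is whatever symbol table maps word ids back to words in the caller:

evaluator = NNLMEvaluator.build_evaluator(
    lm="exp/lm/model.pt",                # placeholder checkpoint path
    bpemodel="data/lang_bpe/bpe.model",  # placeholder BPE model
    token_list=token_list,               # token inventory for the LM
    device="cuda",
    input_type="word_id",
    src_word_table=word_table,           # maps word ids back to words
)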