Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-27 02:34:21 +00:00
decoding and computing ppl with nnlm

parent: a80e58e15d
commit: b9d6119932
23  egs/librispeech/ASR/conformer_ctc/compute_ppl.py  (new file)
@@ -0,0 +1,23 @@
#!/usr/bin/env python3

# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
# Apache 2.0

import torch
import numpy as np
import yaml

from utils.nnlm_evaluator import NNLMEvaluator

# An example of computing PPL with a transformer language model
with open("conformer_ctc/lm_config.yaml") as f:
    lm_args = yaml.safe_load(f)

# TODO(Liyong Guo): make the model definition configurable
lm_args.pop("model_config")

evaluator = NNLMEvaluator.build_evaluator(**lm_args, device="cuda")

res = evaluator.nll("conformer_ctc/data/transcripts/test_clean/text")

# ppl on test_clean is 89.71
print(np.mean(res.nlls))
5012  egs/librispeech/ASR/conformer_ctc/lm_config.yaml  (new file)

File diff suppressed because it is too large
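Although the config diff itself is suppressed, compute_ppl.py above constrains its shape: yaml.safe_load must yield a dict whose "model_config" entry is popped, with the remaining keys matching the keyword parameters of NNLMEvaluator.build_evaluator (defined in utils/nnlm_evaluator.py below). A hypothetical sketch of the loaded dict; every value here is a placeholder, not taken from the suppressed file:

# Assumed shape of yaml.safe_load(open("conformer_ctc/lm_config.yaml")).
lm_args = {
    "model_config": {},  # model definition (contents elided); popped and discarded for now
    "lm": "path/to/lm_checkpoint.pt",  # checkpoint loadable via torch.load
    "bpemodel": "path/to/bpe.model",  # SentencePiece model file
    "token_list": ["<blank>", "<unk>", "...", "<sos/eos>"],
    # token_list holds one entry per BPE token (~5000 of them), which is
    # presumably why the YAML runs to 5012 lines.
}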
133  egs/librispeech/ASR/conformer_ctc/lm_transformer.py  (new file)

@@ -0,0 +1,133 @@
#!/usr/bin/env python3

# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
# Apache 2.0

from typing import Any
from typing import List
from typing import Tuple

import math
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformer import generate_square_subsequent_mask
from transformer import make_pad_mask
from transformer import TransformerEncoderLayer


class Encoder(nn.Module):
    def __init__(
        self,
        embed_unit=128,
        d_model=512,
        nhead=8,
        attention_dropout_rate=0.0,
        num_encoder_layers=16,
        dim_feedforward=2048,
        normalize_before=True,
    ):
        super().__init__()

        self.input_embed = nn.Sequential(
            nn.Linear(embed_unit, d_model),
            nn.LayerNorm(d_model),
            nn.Dropout(attention_dropout_rate),
            nn.ReLU(),
        )

        encoder_layer = TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            normalize_before=True,
            dropout=attention_dropout_rate,
        )

        self.encoders = nn.TransformerEncoder(
            encoder_layer, num_encoder_layers, nn.LayerNorm(d_model)
        )

    def forward(self, xs, token_lens):
        # xs: (N, S, E)
        xs = self.input_embed(xs)
        mask = generate_square_subsequent_mask(xs.shape[1]).to(xs.device)

        src_key_padding_mask = make_pad_mask(token_lens).to(xs.device)

        # xs: (N, S, E) -> (S, N, E)
        xs = xs.transpose(0, 1)
        xs = self.encoders(
            xs, mask=mask, src_key_padding_mask=src_key_padding_mask
        )
        # xs: (S, N, E) -> (N, S, E)
        xs = xs.transpose(0, 1)

        return xs


class TransformerLM(nn.Module):
    def __init__(
        self,
        num_encoder_layers: int = 16,
        vocab_size: int = 5000,
        embed_unit: int = 128,
        d_model: int = 512,
        nhead: int = 8,
        dim_feedforward: int = 2048,
        dropout_rate: float = 0.0,
        ignore_id: int = 0,
    ):
        super().__init__()

        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.ignore_id = ignore_id

        self.embed = nn.Embedding(vocab_size, embed_unit)

        self.encoder = Encoder(
            embed_unit=embed_unit,
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
        )

        self.decoder = nn.Linear(d_model, vocab_size)

    def forward(
        self,
        src: torch.Tensor,
        token_lens,
    ) -> torch.Tensor:
        # src: (N, S)
        x = self.embed(src)
        h = self.encoder(x, token_lens)
        # y: (N, S, vocab_size)
        y = self.decoder(h)
        return y

    def nll(self, xs_pad, target_pad, token_lens):
        # xs_pad/target_pad: (N, S)
        # An example element of xs_pad:
        # <sos> token token token ... token
        #
        # An example element of target_pad:
        # token token token ... token <eos>

        y = self.forward(xs_pad, token_lens)

        # nll: (N * S,)
        nll = F.cross_entropy(
            y.view(-1, y.shape[-1]), target_pad.view(-1), reduction="none"
        )

        # set padded positions to 0.0
        nll.masked_fill_(make_pad_mask(token_lens).to(nll.device).view(-1), 0.0)

        # nll: (N * S,) -> (N, S)
        nll = nll.view(xs_pad.size(0), -1)
        return nll
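TransformerLM.nll returns an (N, S) matrix of per-position negative log-likelihoods with padded positions zeroed, so a row sum gives each sentence's total NLL. A minimal sketch with toy sizes and random weights; the tensors below are illustrative, not from the recipe:

import torch

from lm_transformer import TransformerLM

model = TransformerLM(vocab_size=10, num_encoder_layers=2)
model.eval()

# <sos> = <eos> = vocab_size - 1 = 9; 0 is the padding/ignore id.
xs_pad = torch.tensor([[9, 4, 5, 6], [9, 3, 0, 0]])  # <sos> + tokens
target_pad = torch.tensor([[4, 5, 6, 9], [3, 9, 0, 0]])  # tokens + <eos>
token_lens = [4, 2]

with torch.no_grad():
    nll = model.nll(xs_pad, target_pad, token_lens)  # shape: (2, 4)
# Padded positions hold 0.0, so the row sum is the exact sentence NLL.
sentence_nll = nll.sum(dim=1)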
@@ -898,6 +898,36 @@ def encoder_padding_mask(
        lengths[sequence_idx] = start_frame + num_frames

    lengths = [((i - 1) // 2 - 1) // 2 for i in lengths]
    return make_pad_mask(lengths, max_len)


def make_pad_mask(lengths: List[int], max_len: Optional[int] = None):
    """Make mask tensor representing the padded part.

    Args:
        lengths: (B,).
        max_len: maximum length in the batch

    Returns:
        Tensor: Mask tensor representing the padded part.

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[False, False, False, False, False],
                 [False, False, False, True, True],
                 [False, False, True, True, True]]

        With lengths and max_len.

        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths, 6)
        masks = [[False, False, False, False, False, True],
                 [False, False, False, True, True, True],
                 [False, False, True, True, True, True]]
    """
    if max_len is None:
        max_len = int(max(lengths))

    bs = int(len(lengths))
    seq_range = torch.arange(0, max_len, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
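The hunk display cuts off after `seq_range_expand`. Assuming the remaining added lines follow the conventional ESPnet-style implementation that the docstring describes, the function would finish like this (an inference, not shown in the diff):

    # Assumed continuation: mark every position at or beyond a sequence's
    # length as padding.
    seq_length_expand = torch.tensor(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask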
125  egs/librispeech/ASR/conformer_ctc/utils/nnlm_evaluator.py  (new file)

@@ -0,0 +1,125 @@
#!/usr/bin/env python3

# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
# Apache 2.0

import argparse
import copy
import os
import yaml

from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
import torch
from utils.text_dataset import (
    DatasetOption,
    TextFileDataIterator,
    TokenidsDataIterator,
    AbsLMDataIterator,
)
from utils.numericalizer import Numericalizer
from lm_transformer import TransformerLM

_TYPES_SUPPORTED = ["text_file", "word_id"]


def _validate_input_type(input_type: Optional[str] = None):
    # A valid input_type must be assigned by the caller
    assert input_type is not None
    assert input_type in _TYPES_SUPPORTED


@dataclass(frozen=True)
class PPLResult:
    nlls: List[float]
    ntokens: int
    nwords: int

    @property
    def total_nll(self):
        return sum(self.nlls)

    @property
    def token_ppl(self):
        return np.exp(self.total_nll / self.ntokens)

    @property
    def word_ppl(self):
        return np.exp(self.total_nll / self.nwords)


class NNLMEvaluator(object):
    @torch.no_grad()
    def nll(self, text_source):
        nlls = []
        total_nll = 0.0
        total_ntokens = 0
        total_nwords = 0
        for xs_pad, target_pad, word_lens, token_lens in self.dataset(
            text_source
        ):
            xs_pad = xs_pad.to(self.device)
            target_pad = target_pad.to(self.device)

            nll = self.lm.nll(xs_pad, target_pad, token_lens)
            nll = nll.detach().cpu().numpy().sum(1)
            nlls.extend(nll)
            total_ntokens += sum(token_lens)
            total_nwords += sum(word_lens)
        ppl_result = PPLResult(
            nlls=nlls, ntokens=total_ntokens, nwords=total_nwords
        )
        return ppl_result


# Note: this dataclass intentionally reuses the name NNLMEvaluator; it
# subclasses the plain class above and replaces it in the module namespace.
@dataclass
class NNLMEvaluator(NNLMEvaluator):
    lm: TransformerLM
    dataset: AbsLMDataIterator
    device: Union[str, torch.device]

    @classmethod
    def build_evaluator(
        cls,
        lm: str = None,
        bpemodel=None,
        token_list=None,
        device="cpu",
        input_type="text_file",
        batch_size=32,
        numericalizer=None,
        src_word_table=None,
    ):
        _validate_input_type(input_type)
        assert lm is not None

        model = TransformerLM()
        state_dict = torch.load(lm)
        model.load_state_dict(state_dict)
        model.to(device)

        if numericalizer is None:
            numericalizer = Numericalizer(
                tokenizer_file=bpemodel, token_list=token_list
            )

        dataset_option = DatasetOption(
            input_type=input_type,
            batch_size=batch_size,
            preprocessor=numericalizer,
        )

        if input_type == "text_file":
            dataset = TextFileDataIterator(dataset_option)
        elif input_type == "word_id":
            dataset = TokenidsDataIterator(
                dataset_option,
                numericalizer=numericalizer,
                src_word_table=src_word_table,
            )

        evaluator = NNLMEvaluator(lm=model, dataset=dataset, device=device)
        return evaluator
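Besides the "text_file" path used in compute_ppl.py, build_evaluator supports input_type="word_id", which is what rescore_with_attention_decoder below relies on. A hedged sketch of that path; the checkpoint path, token_list, word_table, and word_id_seqs are placeholders that must come from the actual recipe:

# Illustrative only: paths and tables are not from this commit.
evaluator = NNLMEvaluator.build_evaluator(
    lm="path/to/lm_checkpoint.pt",
    bpemodel="path/to/bpe.model",
    token_list=token_list,  # same list as in lm_config.yaml
    device="cuda",
    input_type="word_id",
    src_word_table=word_table,  # maps word ids back to word strings
)

# word_id_seqs: e.g. n-best paths from a lattice, as lists of word ids.
ppl_result = evaluator.nll(word_id_seqs)
print(ppl_result.token_ppl, ppl_result.word_ppl)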
87  egs/librispeech/ASR/conformer_ctc/utils/numericalizer.py  (new file)

@@ -0,0 +1,87 @@
#!/usr/bin/env python3

# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
# Apache 2.0

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Iterable, List, Optional, Union
from pathlib import Path

import numpy as np

import sentencepiece as spm


class PreProcessor(ABC):
    @abstractmethod
    def __call__(self, text: str) -> List[int]:
        raise NotImplementedError


class Numericalizer(PreProcessor):
    def __init__(self, tokenizer_file, token_list, unk_symbol="<unk>"):
        self.tokenizer_file = tokenizer_file
        self.token_list = token_list
        self._token2idx = None
        self._tokenizer = None
        self._assign_special_symbols()

    def _assign_special_symbols(self):
        # <sos> and <eos> share the same index for models downloaded
        # from the ESPnet model zoo
        assert "<sos/eos>" in self.token2idx or (
            "<sos>" in self.token2idx and "<eos>" in self.token2idx
        )
        assert "<unk>" in self.token2idx
        self.sos_idx = (
            self.token2idx["<sos/eos>"]
            if "<sos/eos>" in self.token2idx
            else self.token2idx["<sos>"]
        )
        self.eos_idx = (
            self.token2idx["<sos/eos>"]
            if "<sos/eos>" in self.token2idx
            else self.token2idx["<eos>"]
        )
        self.unk_idx = self.token2idx["<unk>"]

    @property
    def tokenizer(self):
        if self._tokenizer is None:
            sp = spm.SentencePieceProcessor()
            sp.Load(self.tokenizer_file)
            self._tokenizer = sp
        return self._tokenizer

    def text2tokens(self, line: str) -> List[str]:
        return self.tokenizer.EncodeAsPieces(line)

    def tokens2text(self, tokens: Iterable[str]) -> str:
        return self.tokenizer.DecodePieces(list(tokens))

    @property
    def token2idx(self):
        if self._token2idx is None:
            self._token2idx = {}
            for idx, token in enumerate(self.token_list):
                if token in self._token2idx:
                    raise RuntimeError(f'Symbol "{token}" is duplicated')
                self._token2idx[token] = idx

        return self._token2idx

    def ids2tokens(
        self, integers: Union[np.ndarray, Iterable[int]]
    ) -> List[str]:
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]

    def __call__(self, text: str) -> List[int]:
        tokens = self.text2tokens(text)
        token_idxs = (
            [self.sos_idx]
            + [self.token2idx.get(token, self.unk_idx) for token in tokens]
            + [self.eos_idx]
        )
        return token_idxs
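Numericalizer builds token2idx lazily from token_list and wraps every encoded sentence with the <sos>/<eos> ids, mapping out-of-vocabulary pieces to <unk>. A small sketch; the bpe.model path and token list are placeholders, and the SentencePiece model is only loaded on first use:

from utils.numericalizer import Numericalizer

# Placeholder token list; the real one comes from lm_config.yaml.
token_list = ["<blank>", "<unk>", "▁HELLO", "▁WORLD", "<sos/eos>"]
numericalizer = Numericalizer(
    tokenizer_file="path/to/bpe.model", token_list=token_list
)
# Special symbols are resolved eagerly in __init__:
assert numericalizer.sos_idx == numericalizer.eos_idx == 4
assert numericalizer.unk_idx == 1

# Calling it returns [sos_idx] + piece ids (unk_idx for OOV) + [eos_idx];
# this step is what requires the real SentencePiece model file.
ids = numericalizer("HELLO WORLD")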
125  egs/librispeech/ASR/conformer_ctc/utils/text_dataset.py  (new file)

@@ -0,0 +1,125 @@
#!/usr/bin/env python3

# Copyright 2021 Xiaomi Corporation (Author: Guo Liyong)
# Apache 2.0

from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union

import k2
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from utils.numericalizer import PreProcessor


class CollateFunc(object):
    """Collate function for LMDataset"""

    def __init__(self, pad_index=None):
        # pad_index should be identical to ignore_index of torch.nn.NLLLoss
        # and padding_idx in torch.nn.Embedding
        self.pad_index = pad_index

    def __call__(self, batch: List[List[int]]):
        """
        batch is a ragged 2-d array; each row represents a tokenized text
        whose format is:
        <bos_id> token_id token_id token_id *** <eos_id>
        """
        # data_pad: [batch_size, max_seq_len]
        # max_seq_len == len(max(batch, key=len))
        data_pad = pad_sequence(
            [torch.from_numpy(np.array(x)).long() for x in batch],
            True,
            self.pad_index,
        )
        data_pad = data_pad.contiguous()
        xs_pad = data_pad[:, :-1].contiguous()
        ys_pad = data_pad[:, 1:].contiguous()
        # xs_pad/ys_pad: [batch_size, max_seq_len - 1]
        # - 1 for removing <bos> or <eos>
        return xs_pad, ys_pad


@dataclass
class DatasetOption:
    preprocessor: PreProcessor
    input_type: Optional[str] = "text_file"
    batch_size: int = 32
    pad_value: int = 0


@dataclass
class AbsLMDataIterator(ABC):
    preprocessor: PreProcessor
    input_type: Optional[str] = "text_file"
    batch_size: int = 32
    pad_value: int = 0
    words_txt: Optional[Path] = None
    _collate_fn = None

    @property
    def collate_fn(self):
        if self._collate_fn is None:
            self._collate_fn = CollateFunc(self.pad_value)
        return self._collate_fn

    def _reset_container(self):
        self.token_ids_list = []
        self.token_lens = []
        self.word_lens = []

    @abstractmethod
    def _text_generator(self, text_source):
        raise NotImplementedError

    def __call__(self, text_source):
        """
        Args:
            text_source: may be a text file path or word-id sequences
        """
        self._reset_container()
        for text in self._text_generator(text_source):
            self.word_lens.append(len(text.split()) + 1)  # +1 for <eos>

            token_ids = self.preprocessor(text)
            self.token_ids_list.append(token_ids)
            self.token_lens.append(len(token_ids) - 1)  # -1 to remove <sos>

            if len(self.token_ids_list) == self.batch_size:
                xs_pad, ys_pad = self.collate_fn(self.token_ids_list)

                yield xs_pad, ys_pad, self.word_lens, self.token_lens
                self._reset_container()

        if len(self.token_ids_list) != 0:
            xs_pad, ys_pad = self.collate_fn(self.token_ids_list)
            yield xs_pad, ys_pad, self.word_lens, self.token_lens
            self._reset_container()


class TextFileDataIterator(AbsLMDataIterator):
    def __init__(self, dataset_option):
        super().__init__(**(dataset_option.__dict__))

    def _text_generator(self, text_file):
        with open(text_file, "r") as f:
            for text in f:
                text = text.strip().split(maxsplit=1)[1]
                yield text


class TokenidsDataIterator(AbsLMDataIterator):
    def __init__(self, dataset_option, numericalizer, src_word_table):
        super().__init__(**(dataset_option.__dict__))
        self.numericalizer = numericalizer
        self.src_word_table = src_word_table

    def _text_generator(self, token_ids):
        for utt in token_ids:
            text = " ".join([self.src_word_table[token] for token in utt])
            text = text.upper()
            yield text
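CollateFunc both pads and shifts: dropping the last column of the padded batch gives the model inputs, dropping the first gives the targets. A small worked example with arbitrary token values, using <sos> = <eos> = 9 as TransformerLM would for a 10-token vocabulary:

from utils.text_dataset import CollateFunc

collate = CollateFunc(pad_index=0)
xs_pad, ys_pad = collate([[9, 4, 5, 9], [9, 3, 9]])

# data_pad == [[9, 4, 5, 9],
#              [9, 3, 9, 0]]
assert xs_pad.tolist() == [[9, 4, 5], [9, 3, 9]]  # drop last column
assert ys_pad.tolist() == [[4, 5, 9], [3, 9, 0]]  # drop first column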
@@ -773,6 +773,7 @@ def rescore_with_attention_decoder(
    ngram_lm_scale: Optional[float] = None,
    attention_scale: Optional[float] = None,
    use_double_scores: bool = True,
    nnlm_evaluator=None,
) -> Dict[str, k2.Fsa]:
    """This function extracts `num_paths` paths from the given lattice and uses
    an attention decoder to rescore them. The path with the highest score is
@@ -854,11 +855,21 @@ def rescore_with_attention_decoder(
        sos_id=sos_id,
        eos_id=eos_id,
    )

    assert nll.ndim == 2
    assert nll.shape[0] == len(token_ids)

    attention_scores = -nll.sum(dim=1)

    if nnlm_evaluator is not None:
        aux_labels = k2.RaggedTensor(tokens_shape, nbest.fsa.aux_labels)
        aux_labels = aux_labels.remove_values_leq(0)
        aux_labels = aux_labels.tolist()
        assert len(aux_labels) == len(token_ids)
        ppl_result = nnlm_evaluator.nll(aux_labels)
        nnlm_scores = -torch.tensor(ppl_result.nlls).to(attention_scores.device)
        assert nnlm_scores.shape[0] == len(token_ids)

    if ngram_lm_scale is None:
        ngram_lm_scale_list = [0.01, 0.05, 0.08]
        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
@@ -881,6 +892,8 @@ def rescore_with_attention_decoder(
            + n_scale * ngram_lm_scores.values
            + a_scale * attention_scores
        )
        if nnlm_evaluator is not None:
            tot_scores = tot_scores + nnlm_scores
        ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
        max_indexes = ragged_tot_scores.argmax()
        best_path = k2.index_fsa(nbest.fsa, max_indexes)
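Note how the scores combine in the hunk above: n_scale and a_scale are swept over grids, while the NNLM score is added without a scale of its own. A toy illustration with made-up numbers:

import torch

# Made-up per-path scores for two n-best paths; purely illustrative.
am_scores = torch.tensor([10.0, 9.0])
ngram_lm_scores = torch.tensor([-3.0, -2.0])
attention_scores = torch.tensor([5.0, 6.5])
nnlm_scores = torch.tensor([-4.0, -3.0])

n_scale, a_scale = 0.5, 1.0
tot_scores = am_scores + n_scale * ngram_lm_scores + a_scale * attention_scores
# Applied only when an nnlm_evaluator is passed in, and always unscaled:
tot_scores = tot_scores + nnlm_scores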