From 81d386ef3e3bd68a30000b92a19de9676a0ec828 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Thu, 20 Apr 2023 12:27:43 +0800 Subject: [PATCH] Add compute_ppl.py and ngram_entropy_pruning.py (#1013) --- .../compute_ppl.py | 109 +++ icefall/shared/ngram_entropy_pruning.py | 630 ++++++++++++++++++ 2 files changed, 739 insertions(+) create mode 100755 egs/gigaspeech/ASR/pruned_transducer_stateless2/compute_ppl.py create mode 100755 icefall/shared/ngram_entropy_pruning.py diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/compute_ppl.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/compute_ppl.py new file mode 100755 index 000000000..76306fc4c --- /dev/null +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/compute_ppl.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corp. (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +./pruned_transducer_stateless7/compute_ppl.py \ + --ngram-lm-path ./download/lm/3gram_pruned_1e7.arpa + +""" + + +import argparse +import logging +import math +from typing import Dict, List, Optional, Tuple + +import kenlm +import torch +from asr_datamodule import GigaSpeechAsrDataModule + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--ngram-lm-path", + type=str, + default="download/lm/3gram_pruned_1e7.arpa", + help="The lang dir containing word table and LG graph", + ) + + return parser + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: kenlm.Model, +) -> Dict[str, float]: + """ + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + A ngram lm of kenlm.Model object. + Returns: + Return the perplexity of the giving dataset. + """ + sum_score_log = 0 + sum_n = 0 + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + for text in texts: + sum_n += len(text.split()) + 1 + sum_score_log += -1 * model.score(text) + + ppl = math.pow(10.0, sum_score_log / sum_n) + + return ppl + + +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + logging.info("About to load ngram LM") + model = kenlm.Model(args.ngram_lm_path) + + gigaspeech = GigaSpeechAsrDataModule(args) + + dev_cuts = gigaspeech.dev_cuts() + test_cuts = gigaspeech.test_cuts() + + dev_dl = gigaspeech.test_dataloaders(dev_cuts) + test_dl = gigaspeech.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + ppl = decode_dataset( + dl=test_dl, + model=model, + ) + logging.info(f"{test_set} PPL: {ppl}") + + logging.info("Done!") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/shared/ngram_entropy_pruning.py b/icefall/shared/ngram_entropy_pruning.py new file mode 100755 index 000000000..b1ebee9ea --- /dev/null +++ b/icefall/shared/ngram_entropy_pruning.py @@ -0,0 +1,630 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright 2021 Johns Hopkins University (Author: Ruizhe Huang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +./ngram_entropy_pruning.py \ + -threshold 1e-8 \ + -lm download/lm/4gram.arpa \ + -write-lm download/lm/4gram_pruned_1e8.arpa + +This file is from Kaldi `egs/wsj/s5/utils/lang/ngram_entropy_pruning.py`. +This is an implementation of ``Entropy-based Pruning of Backoff Language Models'' +in the same way as SRILM. +""" + + +import argparse +import gzip +import logging +import math +import re +from collections import OrderedDict, defaultdict +from enum import Enum, unique +from io import StringIO + +parser = argparse.ArgumentParser( + description=""" + Prune an n-gram language model based on the relative entropy + between the original and the pruned model, based on Andreas Stolcke's paper. + An n-gram entry is removed, if the removal causes (training set) perplexity + of the model to increase by less than threshold relative. + + The command takes an arpa file and a pruning threshold as input, + and outputs a pruned arpa file. + """ +) +parser.add_argument("-threshold", type=float, default=1e-6, help="Order of n-gram") +parser.add_argument("-lm", type=str, default=None, help="Path to the input arpa file") +parser.add_argument( + "-write-lm", type=str, default=None, help="Path to output arpa file after pruning" +) +parser.add_argument( + "-minorder", + type=int, + default=1, + help="The minorder parameter limits pruning to ngrams of that length and above.", +) +parser.add_argument( + "-encoding", type=str, default="utf-8", help="Encoding of the arpa file" +) +parser.add_argument( + "-verbose", + type=int, + default=2, + choices=[0, 1, 2, 3, 4, 5], + help="Verbose level, where 0 is most noisy; 5 is most silent", +) +args = parser.parse_args() + +default_encoding = args.encoding +logging.basicConfig( + format="%(asctime)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s", + level=args.verbose * 10, +) + + +class Context(dict): + """ + This class stores data for a context h. + It behaves like a python dict object, except that it has several + additional attributes. + """ + + def __init__(self): + super().__init__() + self.log_bo = None + + +class Arpa: + """ + This is a class that implement the data structure of an APRA LM. + It (as well as some other classes) is modified based on the library + by Stefan Fischer: + https://github.com/sfischer13/python-arpa + """ + + UNK = "" + SOS = "" + EOS = "" + FLOAT_NDIGITS = 7 + base = 10 + + @staticmethod + def _check_input(my_input): + if not my_input: + raise ValueError + elif isinstance(my_input, tuple): + return my_input + elif isinstance(my_input, list): + return tuple(my_input) + elif isinstance(my_input, str): + return tuple(my_input.strip().split(" ")) + else: + raise ValueError + + @staticmethod + def _check_word(input_word): + if not isinstance(input_word, str): + raise ValueError + if " " in input_word: + raise ValueError + + def _replace_unks(self, words): + return tuple((w if w in self else self._unk) for w in words) + + def __init__(self, path=None, encoding=None, unk=None): + self._counts = OrderedDict() + self._ngrams = ( + OrderedDict() + ) # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w) + self._vocabulary = set() + if unk is None: + self._unk = self.UNK + + if path is not None: + self.loadf(path, encoding) + + def __contains__(self, ngram): + h = ngram[:-1] # h is a tuple + w = ngram[-1] # w is a string/word + return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h] + + def contains_word(self, word): + self._check_word(word) + return word in self._vocabulary + + def add_count(self, order, count): + self._counts[order] = count + self._ngrams[order - 1] = defaultdict(Context) + + def update_counts(self): + for order in range(1, self.order() + 1): + count = sum([len(wlist) for _, wlist in self._ngrams[order - 1].items()]) + if count > 0: + self._counts[order] = count + + def add_entry(self, ngram, p, bo=None, order=None): + # Note: ngram is a tuple of strings, e.g. ("w1", "w2", "w3") + h = ngram[:-1] # h is a tuple + w = ngram[-1] # w is a string/word + + # Note that p and bo here are in fact in the log domain (self.base = 10) + h_context = self._ngrams[len(h)][h] + h_context[w] = p + if bo is not None: + self._ngrams[len(ngram)][ngram].log_bo = bo + + for word in ngram: + self._vocabulary.add(word) + + def counts(self): + return sorted(self._counts.items()) + + def order(self): + return max(self._counts.keys(), default=None) + + def vocabulary(self, sort=True): + if sort: + return sorted(self._vocabulary) + else: + return self._vocabulary + + def _entries(self, order): + return ( + self._entry(h, w) + for h, wlist in self._ngrams[order - 1].items() + for w in wlist + ) + + def _entry(self, h, w): + # return the entry for the ngram (h, w) + ngram = h + (w,) + log_p = self._ngrams[len(h)][h][w] + log_bo = self._log_bo(ngram) + if log_bo is not None: + return ( + round(log_p, self.FLOAT_NDIGITS), + ngram, + round(log_bo, self.FLOAT_NDIGITS), + ) + else: + return round(log_p, self.FLOAT_NDIGITS), ngram + + def _log_bo(self, ngram): + if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]: + return self._ngrams[len(ngram)][ngram].log_bo + else: + return None + + def _log_p(self, ngram): + h = ngram[:-1] # h is a tuple + w = ngram[-1] # w is a string/word + if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]: + return self._ngrams[len(h)][h][w] + else: + return None + + def log_p_raw(self, ngram): + log_p = self._log_p(ngram) + if log_p is not None: + return log_p + else: + if len(ngram) == 1: + raise KeyError + else: + log_bo = self._log_bo(ngram[:-1]) + if log_bo is None: + log_bo = 0 + return log_bo + self.log_p_raw(ngram[1:]) + + def log_joint_prob(self, sequence): + # Compute the joint prob of the sequence based on the chain rule + # Note that sequence should be a tuple of strings + # + # Reference: + # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527 + + log_joint_p = 0 + seq = sequence + while len(seq) > 0: + log_joint_p += self.log_p_raw(seq) + seq = seq[:-1] + + # If we're computing the marginal probability of the unigram + # context we have to look up instead since the former + # has prob = 0. + if len(seq) == 1 and seq[0] == self.SOS: + seq = (self.EOS,) + + return log_joint_p + + def set_new_context(self, h): + old_context = self._ngrams[len(h)][h] + self._ngrams[len(h)][h] = Context() + return old_context + + def log_p(self, ngram): + words = self._check_input(ngram) + if self._unk: + words = self._replace_unks(words) + return self.log_p_raw(words) + + def log_s(self, sentence, sos=SOS, eos=EOS): + words = self._check_input(sentence) + if self._unk: + words = self._replace_unks(words) + if sos: + words = (sos,) + words + if eos: + words = words + (eos,) + result = sum(self.log_p_raw(words[:i]) for i in range(1, len(words) + 1)) + if sos: + result = result - self.log_p_raw(words[:1]) + return result + + def p(self, ngram): + return self.base ** self.log_p(ngram) + + def s(self, sentence): + return self.base ** self.log_s(sentence) + + def write(self, fp): + fp.write("\n\\data\\\n") + for order, count in self.counts(): + fp.write("ngram {}={}\n".format(order, count)) + fp.write("\n") + for order, _ in self.counts(): + fp.write("\\{}-grams:\n".format(order)) + for e in self._entries(order): + prob = e[0] + ngram = " ".join(e[1]) + if len(e) == 2: + fp.write("{}\t{}\n".format(prob, ngram)) + elif len(e) == 3: + backoff = e[2] + fp.write("{}\t{}\t{}\n".format(prob, ngram, backoff)) + else: + raise ValueError + fp.write("\n") + fp.write("\\end\\\n") + + +class ArpaParser: + """ + This is a class that implement a parser of an arpa file + """ + + @unique + class State(Enum): + DATA = 1 + COUNT = 2 + HEADER = 3 + ENTRY = 4 + + re_count = re.compile(r"^ngram (\d+)=(\d+)$") + re_header = re.compile(r"^\\(\d+)-grams:$") + re_entry = re.compile( + "^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)" + "\t" + "(\\S+( \\S+)*)" + "(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$" + ) + + def _parse(self, fp): + self._result = [] + self._state = self.State.DATA + self._tmp_model = None + self._tmp_order = None + for line in fp: + line = line.strip() + if self._state == self.State.DATA: + self._data(line) + elif self._state == self.State.COUNT: + self._count(line) + elif self._state == self.State.HEADER: + self._header(line) + elif self._state == self.State.ENTRY: + self._entry(line) + if self._state != self.State.DATA: + raise Exception(line) + return self._result + + def _data(self, line): + if line == "\\data\\": + self._state = self.State.COUNT + self._tmp_model = Arpa() + else: + pass # skip comment line + + def _count(self, line): + match = self.re_count.match(line) + if match: + order = match.group(1) + count = match.group(2) + self._tmp_model.add_count(int(order), int(count)) + elif not line: + self._state = self.State.HEADER # there are no counts + else: + raise Exception(line) + + def _header(self, line): + match = self.re_header.match(line) + if match: + self._state = self.State.ENTRY + self._tmp_order = int(match.group(1)) + elif line == "\\end\\": + self._result.append(self._tmp_model) + self._state = self.State.DATA + self._tmp_model = None + self._tmp_order = None + elif not line: + pass # skip empty line + else: + raise Exception(line) + + def _entry(self, line): + match = self.re_entry.match(line) + if match: + p = self._float_or_int(match.group(1)) + ngram = tuple(match.group(4).split(" ")) + bo_match = match.group(7) + bo = self._float_or_int(bo_match) if bo_match else None + self._tmp_model.add_entry(ngram, p, bo, self._tmp_order) + elif not line: + self._state = self.State.HEADER # last entry + else: + raise Exception(line) + + @staticmethod + def _float_or_int(s): + f = float(s) + i = int(f) + if str(i) == s: # don't drop trailing ".0" + return i + else: + return f + + def load(self, fp): + """Deserialize fp (a file-like object) to a Python object.""" + return self._parse(fp) + + def loadf(self, path, encoding=None): + """Deserialize path (.arpa, .gz) to a Python object.""" + path = str(path) + if path.endswith(".gz"): + with gzip.open(path, mode="rt", encoding=encoding) as f: + return self.load(f) + else: + with open(path, mode="rt", encoding=encoding) as f: + return self.load(f) + + def loads(self, s): + """Deserialize s (a str) to a Python object.""" + with StringIO(s) as f: + return self.load(f) + + def dump(self, obj, fp): + """Serialize obj to fp (a file-like object) in ARPA format.""" + obj.write(fp) + + def dumpf(self, obj, path, encoding=None): + """Serialize obj to path in ARPA format (.arpa, .gz).""" + path = str(path) + if path.endswith(".gz"): + with gzip.open(path, mode="wt", encoding=encoding) as f: + return self.dump(obj, f) + else: + with open(path, mode="wt", encoding=encoding) as f: + self.dump(obj, f) + + def dumps(self, obj): + """Serialize obj to an ARPA formatted str.""" + with StringIO() as f: + self.dump(obj, f) + return f.getvalue() + + +def add_log_p(prev_log_sum, log_p, base): + return math.log(base**log_p + base**prev_log_sum, base) + + +def compute_numerator_denominator(lm, h): + log_sum_seen_h = -math.inf + log_sum_seen_h_lower = -math.inf + base = lm.base + for w, log_p in lm._ngrams[len(h)][h].items(): + log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base) + + ngram = h + (w,) + log_p_lower = lm.log_p_raw(ngram[1:]) + log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower, base) + + numerator = 1.0 - base**log_sum_seen_h + denominator = 1.0 - base**log_sum_seen_h_lower + return numerator, denominator + + +def prune(lm, threshold, minorder): + # Reference: + # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330 + + for i in range( + lm.order(), max(minorder - 1, 1), -1 + ): # i is the order of the ngram (h, w) + logging.info("processing %d-grams ..." % i) + count_pruned_ngrams = 0 + + h_dict = lm._ngrams[i - 1] + for h in list(h_dict.keys()): + # old backoff weight, BOW(h) + log_bow = lm._log_bo(h) + if log_bow is None: + log_bow = 0 + + # Compute numerator and denominator of the backoff weight, + # so that we can quickly compute the BOW adjustment due to + # leaving out one prob. + numerator, denominator = compute_numerator_denominator(lm, h) + + # assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5 + + # Compute the marginal probability of the context, P(h) + h_log_p = lm.log_joint_prob(h) + + all_pruned = True + pruned_w_set = set() + + for w, log_p in h_dict[h].items(): + ngram = h + (w,) + + # lower-order estimate for ngramProb, P(w|h') + backoff_prob = lm.log_p_raw(ngram[1:]) + + # Compute BOW after removing ngram, BOW'(h) + new_log_bow = math.log( + numerator + lm.base**log_p, lm.base + ) - math.log(denominator + lm.base**backoff_prob, lm.base) + + # Compute change in entropy due to removal of ngram + delta_prob = backoff_prob + new_log_bow - log_p + delta_entropy = -(lm.base**h_log_p) * ( + (lm.base**log_p) * delta_prob + + numerator * (new_log_bow - log_bow) + ) + + # compute relative change in model (training set) perplexity + perp_change = lm.base**delta_entropy - 1.0 + + pruned = threshold > 0 and perp_change < threshold + + # Make sure we don't prune ngrams whose backoff nodes are needed + if ( + pruned + and len(ngram) in lm._ngrams + and len(lm._ngrams[len(ngram)][ngram]) > 0 + ): + pruned = False + + logging.debug( + "CONTEXT " + + str(h) + + " WORD " + + w + + " CONTEXTPROB %f " % h_log_p + + " OLDPROB %f " % log_p + + " NEWPROB %f " % (backoff_prob + new_log_bow) + + " DELTA-H %f " % delta_entropy + + " DELTA-LOGP %f " % delta_prob + + " PPL-CHANGE %f " % perp_change + + " PRUNED " + + str(pruned) + ) + + if pruned: + pruned_w_set.add(w) + count_pruned_ngrams += 1 + else: + all_pruned = False + + # If we removed all ngrams for this context we can + # remove the context itself, but only if the present + # context is not a prefix to a longer one. + if all_pruned and len(pruned_w_set) == len(h_dict[h]): + del h_dict[ + h + ] # this context h is no longer needed, as its ngram prob is stored at its own context h' + elif len(pruned_w_set) > 0: + # The pruning for this context h is actually done here + old_context = lm.set_new_context(h) + + for w, p_w in old_context.items(): + if w not in pruned_w_set: + lm.add_entry( + h + (w,), p_w + ) # the entry hw is stored at the context h + + # We need to recompute the back-off weight, but + # this can only be done after completing the pruning + # of the lower-order ngrams. + # Reference: + # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124 + + logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i)) + + # recompute backoff weights + for i in range( + max(minorder - 1, 1) + 1, lm.order() + 1 + ): # be careful of this order: from low- to high-order + for h in lm._ngrams[i - 1]: + numerator, denominator = compute_numerator_denominator(lm, h) + new_log_bow = math.log(numerator, lm.base) - math.log(denominator, lm.base) + lm._ngrams[len(h)][h].log_bo = new_log_bow + + # update counts + lm.update_counts() + + return + + +def check_h_is_valid(lm, h): + sum_under_h = sum( + [lm.base ** lm.log_p_raw(h + (w,)) for w in lm.vocabulary(sort=False)] + ) + if abs(sum_under_h - 1.0) > 1e-6: + logging.info("warning: %s %f" % (str(h), sum_under_h)) + return False + else: + return True + + +def validate_lm(lm): + # sanity check if the conditional probability sums to one under each context h + for i in range(lm.order(), 0, -1): # i is the order of the ngram (h, w) + logging.info("validating %d-grams ..." % i) + h_dict = lm._ngrams[i - 1] + for h in h_dict.keys(): + check_h_is_valid(lm, h) + + +def compare_two_apras(path1, path2): + pass + + +if __name__ == "__main__": + # load an arpa file + logging.info("Loading the arpa file from %s" % args.lm) + parser = ArpaParser() + models = parser.loadf(args.lm, encoding=default_encoding) + lm = models[0] # ARPA files may contain several models. + logging.info("Stats before pruning:") + for i, cnt in lm.counts(): + logging.info("ngram %d=%d" % (i, cnt)) + + # prune it, the language model will be modified in-place + logging.info("Start pruning the model with threshold=%.3E..." % args.threshold) + prune(lm, args.threshold, args.minorder) + + # validate_lm(lm) + + # write the arpa language model to a file + logging.info("Stats after pruning:") + for i, cnt in lm.counts(): + logging.info("ngram %d=%d" % (i, cnt)) + logging.info("Saving the pruned arpa file to %s" % args.write_lm) + parser.dumpf(lm, args.write_lm, encoding=default_encoding) + logging.info("Done.")