diff --git a/egs/librispeech/ASR/local/ngram_entropy_pruning.py b/egs/librispeech/ASR/local/ngram_entropy_pruning.py
deleted file mode 100644
index d0ffa92f6..000000000
--- a/egs/librispeech/ASR/local/ngram_entropy_pruning.py
+++ /dev/null
@@ -1,627 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-# Copyright 2021 Johns Hopkins University (Author: Ruizhe Huang)
-# Apache 2.0.
-
-# This is an implementation of ``Entropy-based Pruning of Backoff Language Models''
-# in the same way as SRILM.
-
-################################################
-# Useful links/References:
-################################################
-# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330
-# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2124
-# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527
-# https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124
-# https://github.com/sfischer13/python-arpa
-
-################################################
-# How to use:
-################################################
-# python3 ngram_entropy_pruning.py -threshold $threshold -lm $input_lm -write-lm $pruned_lm
-
-################################################
-# Equivalent SRILM commands:
-################################################
-# to_prune_lm=egs/swbd/s5c/data/local/lm/sw1.o3g.kn.gz
-# vocab=egs/swbd/s5c/data/local/lm/wordlist
-# order=3
-# oov_symbol="<unk>"
-# threshold=4.7e-5
-# pruned_lm=temp.${threshold}.gz
-# ngram -unk -map-unk "$oov_symbol" -vocab $vocab -order $order -prune ${threshold} -lm ${to_prune_lm} -write-lm ${pruned_lm}
-#
-# lm=
-# ngram -unk -lm $lm -ppl heldout
-# ngram -unk -lm $lm -ppl heldout -debug 3
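-
-################################################
-# Pruning criterion (informal sketch):
-################################################
-# The quantities below are the ones computed in prune() further down; see
-# Stolcke's paper and the SRILM sources linked above for the derivation.
-# For an n-gram (h, w) with history h and backed-off history h', with all
-# probabilities in base-10 logs:
-#
-#   BOW'(h)  = (num + P(w|h)) / (den + P(w|h'))
-#              # backoff weight of h if (h, w) were removed; num and den
-#              # come from compute_numerator_denominator(lm, h)
-#   dlogP    = log P(w|h') + log BOW'(h) - log P(w|h)
-#   dEntropy = -P(h) * (P(w|h) * dlogP + num * (log BOW'(h) - log BOW(h)))
-#
-# and (h, w) is pruned iff base**dEntropy - 1 < threshold.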
- """ - def __init__(self): - super().__init__() - self.log_bo = None - - -class Arpa: - """ - This is a class that implement the data structure of an APRA LM. - It (as well as some other classes) is modified based on the library - by Stefan Fischer: - https://github.com/sfischer13/python-arpa - """ - - UNK = '' - SOS = '' - EOS = '' - FLOAT_NDIGITS = 7 - base = 10 - - @staticmethod - def _check_input(my_input): - if not my_input: - raise ValueError - elif isinstance(my_input, tuple): - return my_input - elif isinstance(my_input, list): - return tuple(my_input) - elif isinstance(my_input, str): - return tuple(my_input.strip().split(' ')) - else: - raise ValueError - - @staticmethod - def _check_word(input_word): - if not isinstance(input_word, str): - raise ValueError - if ' ' in input_word: - raise ValueError - - def _replace_unks(self, words): - return tuple((w if w in self else self._unk) for w in words) - - def __init__(self, path=None, encoding=None, unk=None): - self._counts = OrderedDict() - self._ngrams = OrderedDict( - ) # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w) - self._vocabulary = set() - if unk is None: - self._unk = self.UNK - - if path is not None: - self.loadf(path, encoding) - - def __contains__(self, ngram): - h = ngram[:-1] # h is a tuple - w = ngram[-1] # w is a string/word - return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h] - - def contains_word(self, word): - self._check_word(word) - return word in self._vocabulary - - def add_count(self, order, count): - self._counts[order] = count - self._ngrams[order - 1] = defaultdict(Context) - - def update_counts(self): - for order in range(1, self.order() + 1): - count = sum( - [len(wlist) for _, wlist in self._ngrams[order - 1].items()]) - if count > 0: - self._counts[order] = count - - def add_entry(self, ngram, p, bo=None, order=None): - # Note: ngram is a tuple of strings, e.g. 
("w1", "w2", "w3") - h = ngram[:-1] # h is a tuple - w = ngram[-1] # w is a string/word - - # Note that p and bo here are in fact in the log domain (self.base = 10) - h_context = self._ngrams[len(h)][h] - h_context[w] = p - if bo is not None: - self._ngrams[len(ngram)][ngram].log_bo = bo - - for word in ngram: - self._vocabulary.add(word) - - def counts(self): - return sorted(self._counts.items()) - - def order(self): - return max(self._counts.keys(), default=None) - - def vocabulary(self, sort=True): - if sort: - return sorted(self._vocabulary) - else: - return self._vocabulary - - def _entries(self, order): - return (self._entry(h, w) - for h, wlist in self._ngrams[order - 1].items() for w in wlist) - - def _entry(self, h, w): - # return the entry for the ngram (h, w) - ngram = h + (w, ) - log_p = self._ngrams[len(h)][h][w] - log_bo = self._log_bo(ngram) - if log_bo is not None: - return round(log_p, self.FLOAT_NDIGITS), ngram, round( - log_bo, self.FLOAT_NDIGITS) - else: - return round(log_p, self.FLOAT_NDIGITS), ngram - - def _log_bo(self, ngram): - if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]: - return self._ngrams[len(ngram)][ngram].log_bo - else: - return None - - def _log_p(self, ngram): - h = ngram[:-1] # h is a tuple - w = ngram[-1] # w is a string/word - if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]: - return self._ngrams[len(h)][h][w] - else: - return None - - def log_p_raw(self, ngram): - log_p = self._log_p(ngram) - if log_p is not None: - return log_p - else: - if len(ngram) == 1: - raise KeyError - else: - log_bo = self._log_bo(ngram[:-1]) - if log_bo is None: - log_bo = 0 - return log_bo + self.log_p_raw(ngram[1:]) - - def log_joint_prob(self, sequence): - # Compute the joint prob of the sequence based on the chain rule - # Note that sequence should be a tuple of strings - # - # Reference: - # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527 - - log_joint_p = 0 - seq = sequence - while len(seq) > 0: - log_joint_p += self.log_p_raw(seq) - seq = seq[:-1] - - # If we're computing the marginal probability of the unigram - # context we have to look up instead since the former - # has prob = 0. 
-
-    def log_joint_prob(self, sequence):
-        # Compute the joint prob of the sequence based on the chain rule
-        # Note that sequence should be a tuple of strings
-        #
-        # Reference:
-        # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527
-
-        log_joint_p = 0
-        seq = sequence
-        while len(seq) > 0:
-            log_joint_p += self.log_p_raw(seq)
-            seq = seq[:-1]
-
-            # If we're computing the marginal probability of the unigram
-            # context <s>, we have to look up </s> instead, since the former
-            # has prob = 0.
-            if len(seq) == 1 and seq[0] == self.SOS:
-                seq = (self.EOS,)
-
-        return log_joint_p
-
-    def set_new_context(self, h):
-        old_context = self._ngrams[len(h)][h]
-        self._ngrams[len(h)][h] = Context()
-        return old_context
-
-    def log_p(self, ngram):
-        words = self._check_input(ngram)
-        if self._unk:
-            words = self._replace_unks(words)
-        return self.log_p_raw(words)
-
-    def log_s(self, sentence, sos=SOS, eos=EOS):
-        words = self._check_input(sentence)
-        if self._unk:
-            words = self._replace_unks(words)
-        if sos:
-            words = (sos,) + words
-        if eos:
-            words = words + (eos,)
-        result = sum(
-            self.log_p_raw(words[:i]) for i in range(1, len(words) + 1))
-        if sos:
-            result = result - self.log_p_raw(words[:1])
-        return result
-
-    def p(self, ngram):
-        return self.base**self.log_p(ngram)
-
-    def s(self, sentence):
-        return self.base**self.log_s(sentence)
-
-    def write(self, fp):
-        fp.write('\n\\data\\\n')
-        for order, count in self.counts():
-            fp.write('ngram {}={}\n'.format(order, count))
-        fp.write('\n')
-        for order, _ in self.counts():
-            fp.write('\\{}-grams:\n'.format(order))
-            for e in self._entries(order):
-                prob = e[0]
-                ngram = ' '.join(e[1])
-                if len(e) == 2:
-                    fp.write('{}\t{}\n'.format(prob, ngram))
-                elif len(e) == 3:
-                    backoff = e[2]
-                    fp.write('{}\t{}\t{}\n'.format(prob, ngram, backoff))
-                else:
-                    raise ValueError
-            fp.write('\n')
-        fp.write('\\end\\\n')
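-
-
-# Arpa.write() above emits the standard ARPA text format, which is also the
-# format ArpaParser below reads. A made-up sketch of such a file (fields are
-# tab-separated; the optional last field is the backoff weight):
-#
-#   \data\
-#   ngram 1=4
-#   ngram 2=2
-#
-#   \1-grams:
-#   -1.2  <unk>
-#   -0.7  <s>  -0.4
-#   -0.9  </s>
-#   -0.5  a    -0.3
-#
-#   \2-grams:
-#   -0.4  <s> a
-#   -0.3  a </s>
-#
-#   \end\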
-
-
-class ArpaParser:
-    """
-    This class implements a parser of an ARPA file
-    """
-
-    @unique
-    class State(Enum):
-        DATA = 1
-        COUNT = 2
-        HEADER = 3
-        ENTRY = 4
-
-    re_count = re.compile(r'^ngram (\d+)=(\d+)$')
-    re_header = re.compile(r'^\\(\d+)-grams:$')
-    re_entry = re.compile('^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)'
-                          '\t'
-                          '(\\S+( \\S+)*)'
-                          '(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$')
-
-    def _parse(self, fp):
-        self._result = []
-        self._state = self.State.DATA
-        self._tmp_model = None
-        self._tmp_order = None
-        for line in fp:
-            line = line.strip()
-            if self._state == self.State.DATA:
-                self._data(line)
-            elif self._state == self.State.COUNT:
-                self._count(line)
-            elif self._state == self.State.HEADER:
-                self._header(line)
-            elif self._state == self.State.ENTRY:
-                self._entry(line)
-        if self._state != self.State.DATA:
-            raise Exception(line)
-        return self._result
-
-    def _data(self, line):
-        if line == '\\data\\':
-            self._state = self.State.COUNT
-            self._tmp_model = Arpa()
-        else:
-            pass  # skip comment line
-
-    def _count(self, line):
-        match = self.re_count.match(line)
-        if match:
-            order = match.group(1)
-            count = match.group(2)
-            self._tmp_model.add_count(int(order), int(count))
-        elif not line:
-            self._state = self.State.HEADER  # there are no counts
-        else:
-            raise Exception(line)
-
-    def _header(self, line):
-        match = self.re_header.match(line)
-        if match:
-            self._state = self.State.ENTRY
-            self._tmp_order = int(match.group(1))
-        elif line == '\\end\\':
-            self._result.append(self._tmp_model)
-            self._state = self.State.DATA
-            self._tmp_model = None
-            self._tmp_order = None
-        elif not line:
-            pass  # skip empty line
-        else:
-            raise Exception(line)
-
-    def _entry(self, line):
-        match = self.re_entry.match(line)
-        if match:
-            p = self._float_or_int(match.group(1))
-            ngram = tuple(match.group(4).split(' '))
-            bo_match = match.group(7)
-            bo = self._float_or_int(bo_match) if bo_match else None
-            self._tmp_model.add_entry(ngram, p, bo, self._tmp_order)
-        elif not line:
-            self._state = self.State.HEADER  # last entry
-        else:
-            raise Exception(line)
-
-    @staticmethod
-    def _float_or_int(s):
-        f = float(s)
-        i = int(f)
-        if str(i) == s:  # don't drop trailing ".0"
-            return i
-        else:
-            return f
-
-    def load(self, fp):
-        """Deserialize fp (a file-like object) to a Python object."""
-        return self._parse(fp)
-
-    def loadf(self, path, encoding=None):
-        """Deserialize path (.arpa, .gz) to a Python object."""
-        path = str(path)
-        if path.endswith('.gz'):
-            with gzip.open(path, mode='rt', encoding=encoding) as f:
-                return self.load(f)
-        else:
-            with open(path, mode='rt', encoding=encoding) as f:
-                return self.load(f)
-
-    def loads(self, s):
-        """Deserialize s (a str) to a Python object."""
-        with StringIO(s) as f:
-            return self.load(f)
-
-    def dump(self, obj, fp):
-        """Serialize obj to fp (a file-like object) in ARPA format."""
-        obj.write(fp)
-
-    def dumpf(self, obj, path, encoding=None):
-        """Serialize obj to path in ARPA format (.arpa, .gz)."""
-        path = str(path)
-        if path.endswith('.gz'):
-            with gzip.open(path, mode='wt', encoding=encoding) as f:
-                return self.dump(obj, f)
-        else:
-            with open(path, mode='wt', encoding=encoding) as f:
-                self.dump(obj, f)
-
-    def dumps(self, obj):
-        """Serialize obj to an ARPA formatted str."""
-        with StringIO() as f:
-            self.dump(obj, f)
-            return f.getvalue()
-
-
-def add_log_p(prev_log_sum, log_p, base):
-    return math.log(base**log_p + base**prev_log_sum, base)
-
-
-def compute_numerator_denominator(lm, h):
-    log_sum_seen_h = -math.inf
-    log_sum_seen_h_lower = -math.inf
-    base = lm.base
-    for w, log_p in lm._ngrams[len(h)][h].items():
-        log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base)
-
-        ngram = h + (w,)
-        log_p_lower = lm.log_p_raw(ngram[1:])
-        log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower,
-                                         base)
-
-    numerator = 1.0 - base**log_sum_seen_h
-    denominator = 1.0 - base**log_sum_seen_h_lower
-    return numerator, denominator
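-
-
-# Consistency note (a sketch): for a well-formed model, the stored backoff
-# weight of a context h already satisfies
-#
-#   log_bo(h) == math.log(numerator, base) - math.log(denominator, base)
-#
-# which is exactly how prune() recomputes BOW(h) after entries are removed
-# (see also the commented-out assert inside prune()).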
-
-
-def prune(lm, threshold, minorder):
-    # Reference:
-    # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330
-
-    for i in range(lm.order(), max(minorder - 1, 1), -1):
-        # i is the order of the ngram (h, w)
-        logging.info("processing %d-grams ..." % i)
-        count_pruned_ngrams = 0
-
-        h_dict = lm._ngrams[i - 1]
-        for h in list(h_dict.keys()):
-            # old backoff weight, BOW(h)
-            log_bow = lm._log_bo(h)
-            if log_bow is None:
-                log_bow = 0
-
-            # Compute numerator and denominator of the backoff weight,
-            # so that we can quickly compute the BOW adjustment due to
-            # leaving out one prob.
-            numerator, denominator = compute_numerator_denominator(lm, h)
-
-            # assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5
-
-            # Compute the marginal probability of the context, P(h)
-            h_log_p = lm.log_joint_prob(h)
-
-            all_pruned = True
-            pruned_w_set = set()
-
-            for w, log_p in h_dict[h].items():
-                ngram = h + (w,)
-
-                # lower-order estimate for ngramProb, P(w|h')
-                backoff_prob = lm.log_p_raw(ngram[1:])
-
-                # Compute BOW after removing ngram, BOW'(h)
-                new_log_bow = math.log(numerator + lm.base**log_p, lm.base) - \
-                    math.log(denominator + lm.base**backoff_prob, lm.base)
-
-                # Compute change in entropy due to removal of ngram
-                delta_prob = backoff_prob + new_log_bow - log_p
-                delta_entropy = -(lm.base**h_log_p) * \
-                    ((lm.base**log_p) * delta_prob +
-                     numerator * (new_log_bow - log_bow))
-
-                # compute relative change in model (training set) perplexity
-                perp_change = lm.base**delta_entropy - 1.0
-
-                pruned = threshold > 0 and perp_change < threshold
-
-                # Make sure we don't prune ngrams whose backoff nodes are needed
-                if pruned and \
-                        len(ngram) in lm._ngrams and \
-                        len(lm._ngrams[len(ngram)][ngram]) > 0:
-                    pruned = False
-
-                logging.debug("CONTEXT " + str(h) + " WORD " + w +
-                              " CONTEXTPROB %f " % h_log_p +
-                              " OLDPROB %f " % log_p +
-                              " NEWPROB %f " % (backoff_prob + new_log_bow) +
-                              " DELTA-H %f " % delta_entropy +
-                              " DELTA-LOGP %f " % delta_prob +
-                              " PPL-CHANGE %f " % perp_change +
-                              " PRUNED " + str(pruned))
-
-                if pruned:
-                    pruned_w_set.add(w)
-                    count_pruned_ngrams += 1
-                else:
-                    all_pruned = False
-
-            # If we removed all ngrams for this context we can
-            # remove the context itself, but only if the present
-            # context is not a prefix to a longer one.
-            if all_pruned and len(pruned_w_set) == len(h_dict[h]):
-                # This context h is no longer needed, as its ngram prob
-                # is stored at its own context h'.
-                del h_dict[h]
-            elif len(pruned_w_set) > 0:
-                # The pruning for this context h is actually done here
-                old_context = lm.set_new_context(h)
-
-                for w, p_w in old_context.items():
-                    if w not in pruned_w_set:
-                        # the entry hw is stored at the context h
-                        lm.add_entry(h + (w,), p_w)
-
-                # We need to recompute the back-off weight, but
-                # this can only be done after completing the pruning
-                # of the lower-order ngrams.
-                # Reference:
-                # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124
-
-        logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i))
-
-    # Recompute backoff weights.
-    # Be careful of the order here: from low- to high-order.
-    for i in range(max(minorder - 1, 1) + 1, lm.order() + 1):
-        for h in lm._ngrams[i - 1]:
-            numerator, denominator = compute_numerator_denominator(lm, h)
-            new_log_bow = math.log(numerator, lm.base) - \
-                math.log(denominator, lm.base)
-            lm._ngrams[len(h)][h].log_bo = new_log_bow
-
-    # update counts
-    lm.update_counts()
-
-    return
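-
-
-# Hypothetical programmatic use (a sketch only; this script is normally run
-# from the command line as shown at the top of this file, and the file names
-# here are made up):
-#
-#   parser = ArpaParser()
-#   lm = parser.loadf("P.arpa", encoding="utf-8")[0]  # first model in the file
-#   prune(lm, threshold=1e-6, minorder=1)             # prunes lm in place
-#   parser.dumpf(lm, "P-pruned.arpa", encoding="utf-8")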
-
-
-def check_h_is_valid(lm, h):
-    sum_under_h = sum(
-        lm.base**lm.log_p_raw(h + (w,)) for w in lm.vocabulary(sort=False))
-    if abs(sum_under_h - 1.0) > 1e-6:
-        logging.info("warning: %s %f" % (str(h), sum_under_h))
-        return False
-    else:
-        return True
-
-
-def validate_lm(lm):
-    # sanity check if the conditional probability sums to one under each context h
-    for i in range(lm.order(), 0, -1):  # i is the order of the ngram (h, w)
-        logging.info("validating %d-grams ..." % i)
-        h_dict = lm._ngrams[i - 1]
-        for h in h_dict.keys():
-            check_h_is_valid(lm, h)
-
-
-def compare_two_arpas(path1, path2):
-    pass
-
-
-if __name__ == '__main__':
-    # load an arpa file
-    logging.info("Loading the arpa file from %s" % args.lm)
-    parser = ArpaParser()
-    models = parser.loadf(args.lm, encoding=default_encoding)
-    lm = models[0]  # ARPA files may contain several models.
-    logging.info("Stats before pruning:")
-    for i, cnt in lm.counts():
-        logging.info("ngram %d=%d" % (i, cnt))
-
-    # prune it; the language model is modified in-place
-    logging.info("Start pruning the model with threshold=%.3E..." %
-                 args.threshold)
-    prune(lm, args.threshold, args.minorder)
-
-    # validate_lm(lm)
-
-    # write the arpa language model to a file
-    logging.info("Stats after pruning:")
-    for i, cnt in lm.counts():
-        logging.info("ngram %d=%d" % (i, cnt))
-    logging.info("Saving the pruned arpa file to %s" % args.write_lm)
-    parser.dumpf(lm, args.write_lm, encoding=default_encoding)
-    logging.info("Done.")
diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index 375da0d79..c8e093177 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -159,51 +159,13 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
       -lm data/lang_bpe/P.arpa
   fi
 
-  # TODO: Use egs/wsj/s5/utils/lang/ngram_entropy_pruning.py
-  # from kaldi to prune P if it causes OOM later
-
-  if [ ! -f data/lang_bpe/P-no-prune.fst.txt ]; then
+  if [ ! -f data/lang_bpe/P.fst.txt ]; then
     python3 -m kaldilm \
       --read-symbol-table="data/lang_bpe/tokens.txt" \
       --disambig-symbol='#0' \
       --max-order=2 \
-      data/lang_bpe/P.arpa > data/lang_bpe/P-no-prune.fst.txt
+      data/lang_bpe/P.arpa > data/lang_bpe/P.fst.txt
   fi
-
-  thresholds=(
-    1e-6
-    1e-7
-  )
-  for threshold in ${thresholds[@]}; do
-    if [ ! -f data/lang_bpe/P-pruned.${threshold}.arpa ]; then
-      python3 ./local/ngram_entropy_pruning.py \
-        -threshold $threshold \
-        -lm data/lang_bpe/P.arpa \
-        -write-lm data/lang_bpe/P-pruned.${threshold}.arpa
-    fi
-
-    if [ ! -f data/lang_bpe/P-pruned.${threshold}.fst.txt ]; then
-      python3 -m kaldilm \
-        --read-symbol-table="data/lang_bpe/tokens.txt" \
-        --disambig-symbol='#0' \
-        --max-order=2 \
-        data/lang_bpe/P-pruned.${threshold}.arpa > data/lang_bpe/P-pruned.${threshold}.fst.txt
-    fi
-  done
-
-  if [ ! -f data/lang_bpe/P-uni.fst.txt ]; then
-    python3 -m kaldilm \
-      --read-symbol-table="data/lang_bpe/tokens.txt" \
-      --disambig-symbol='#0' \
-      --max-order=1 \
-      data/lang_bpe/P.arpa > data/lang_bpe/P-uni.fst.txt
-  fi
-
-  ( cd data/lang_bpe;
-    # ln -sfv P-pruned.1e-6.fst.txt P.fst.txt
-    ln -sfv P-no-prune.fst.txt P.fst.txt
-  )
-
-  rm -fv data/lang_bpe/P.pt data/lang_bpe/ctc_topo_P.pt
 fi
 
 if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
diff --git a/requirements.txt b/requirements.txt
index a54edf118..710048fed 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 kaldilm
 kaldialign
 sentencepiece>=0.1.96
+tensorboard