#!/usr/bin/env python3

# Copyright 2016  Johns Hopkins University (Author: Daniel Povey)
#           2018  Ruizhe Huang
# Apache 2.0.

# This is an implementation of computing Kneser-Ney smoothed language model
# in the same way as srilm. This is a back-off, unmodified version of
# Kneser-Ney smoothing, which produces the same results as the following
# command (as an example) of srilm:
#
# $ ngram-count -order 4 -kn-modify-counts-at-end -ukndiscount -gt1min 0 -gt2min 0 -gt3min 0 -gt4min 0 \
#     -text corpus.txt -lm lm.arpa
#
# The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py
# The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html

import argparse
import io
import math
import os
import re
import sys
from collections import Counter, defaultdict

parser = argparse.ArgumentParser(
    description="""
    Generate kneser-ney language model as arpa format. By default,
    it will read the corpus from standard input, and output to standard output.
    """
)
parser.add_argument(
    "-ngram-order",
    type=int,
    default=4,
    choices=[2, 3, 4, 5, 6, 7],
    help="Order of n-gram",
)
parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
parser.add_argument(
    "-lm", type=str, default=None, help="Path to output arpa file for language models"
)
parser.add_argument(
    "-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level"
)
# Start from the default option values so that importing this module (e.g. for
# testing) neither consumes nor rejects the importer's sys.argv. The real
# command line is parsed in the __main__ section at the bottom, which rebinds
# this module-level name; NgramCounts reads `args.verbose` from here.
args = parser.parse_args([])

# For encoding-agnostic scripts, we assume byte stream as input.
# Need to be very careful about the use of strip() and split()
# in this case, because there is a latin-1 whitespace character
# (nbsp) which is part of the unicode encoding range.
# Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
default_encoding = "latin-1"

strip_chars = " \t\r\n"
whitespace = re.compile("[ \t]+")


class CountsForHistory:
    # This class (which is more like a struct) stores the counts seen in a
    # particular history-state. It is used inside class NgramCounts.
    # It really does the job of a dict from word to count, but it also
    # keeps track of the total count, of the distinct left-context words each
    # predicted word was seen with, and of the computed probability/back-off.

    def __init__(self):
        # raw count of each predicted word seen after this history
        self.word_to_count = defaultdict(int)
        # using a set to count the number of unique contexts; its size is the
        # Kneser-Ney "modified count" n(*_z) of the predicted word
        self.word_to_context = defaultdict(set)
        self.word_to_f = dict()  # discounted probability
        self.word_to_bow = dict()  # back-off weight
        self.total_count = 0

    def words(self):
        # Return a view of the predicted words seen in this history.
        return self.word_to_count.keys()

    def __str__(self):
        # e.g. returns ' total=12: 3 -> 4, 4 -> 6, -1 -> 2'
        return " total={0}: {1}".format(
            str(self.total_count),
            ", ".join(
                [
                    "{0} -> {1}".format(word, count)
                    for word, count in self.word_to_count.items()
                ]
            ),
        )

    def add_count(self, predicted_word, context_word, count):
        # Accumulate `count` occurrences of `predicted_word` in this history.
        # context_word: the word preceding the whole n-gram, or None when
        # there is none to record (see NgramCounts.add_raw_counts_from_line).
        assert count >= 0

        self.total_count += count
        self.word_to_count[predicted_word] += count
        if context_word is not None:
            self.word_to_context[predicted_word].add(context_word)


class NgramCounts:
    # A note on data-structure. Words are the whitespace-separated tokens of
    # the corpus (strings), plus the sentence-boundary symbols bos_symbol and
    # eos_symbol. We store n-gram counts as a list, indexed by (history-length
    # == n-gram order minus one), of dicts from histories to CountsForHistory
    # objects, where histories are tuples of words. For instance, when
    # accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd
    # do as follows: self.counts[3][('5','6','7')].add_count('8', ...) where the
    # [3] indexes the list and the [('5','6','7')] indexes a dict.

    def __init__(self, ngram_order, bos_symbol="<s>", eos_symbol="</s>"):
        # ngram_order: highest n-gram order to model (>= 2).
        # bos_symbol / eos_symbol: tokens wrapped around every sentence; they
        # must differ, since cal_bow() treats eos_symbol specially.
        assert ngram_order >= 2

        self.ngram_order = ngram_order
        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        # self.counts[n] maps a history of length n to its CountsForHistory.
        self.counts = []
        for n in range(ngram_order):
            self.counts.append(defaultdict(CountsForHistory))

        self.d = []  # list of discounting factor for each order of ngram

    # adds a raw count (called while processing input data).
    # Suppose we see the sequence '6 7 8 9' and ngram_order=4, 'history'
    # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be
    # 1.
    def add_count(self, history, predicted_word, context_word, count):
        self.counts[len(history)][history].add_count(
            predicted_word, context_word, count
        )

    # 'line' is a string containing a sequence of whitespace-separated words.
    # This function adds the un-smoothed counts from this line of text.
    def add_raw_counts_from_line(self, line):
        if line == "":
            # whitespace.split("") would yield [""], i.e. a spurious empty word.
            words = [self.bos_symbol, self.eos_symbol]
        else:
            words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol]

        for i in range(len(words)):
            for n in range(1, self.ngram_order + 1):
                if i + n > len(words):
                    break
                ngram = words[i : i + n]
                predicted_word = ngram[-1]
                history = tuple(ngram[:-1])
                # Left contexts (for KN modified counts) are only tracked for
                # lower-order n-grams that have a preceding word.
                if i == 0 or n == self.ngram_order:
                    context_word = None
                else:
                    context_word = words[i - 1]

                self.add_count(history, predicted_word, context_word, 1)

    def _add_raw_counts_from_lines(self, lines):
        # Shared reader loop: strip and count every line of `lines`, then
        # report how many were processed (always when none were, otherwise
        # only if -verbose > 0).
        lines_processed = 0
        for line in lines:
            line = line.strip(strip_chars)
            self.add_raw_counts_from_line(line)
            lines_processed += 1
        if lines_processed == 0 or args.verbose > 0:
            print(
                "make_phone_lm.py: processed {0} lines of input".format(
                    lines_processed
                ),
                file=sys.stderr,
            )

    def add_raw_counts_from_standard_input(self):
        # byte stream as input, decoded as latin-1 (see default_encoding)
        infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding)
        self._add_raw_counts_from_lines(infile)

    def add_raw_counts_from_file(self, filename):
        with open(filename, encoding=default_encoding) as fp:
            self._add_raw_counts_from_lines(fp)

    def cal_discounting_constants(self):
        # For each order N of N-grams, we calculate discounting constant
        # D_N = n1_N / (n1_N + 2 * n2_N), where n1_N (resp. n2_N) is the number
        # of unique N-grams with count = 1 (resp. count = 2) (counts-of-counts).
        # This constant is used similarly to absolute discounting.
        # Result: self.d is a list of floats indexed by history length, i.e.
        # self.d[N - 1] = D_N.
        # Raises ValueError if some order has no N-gram with count 1 or 2.

        # For the lowest order, i.e., 1-gram, we do not need to discount, thus
        # the constant is 0. This is a special case: as we currently assume
        # having seen all vocabularies in the dictionary, but perhaps this is
        # not the case for some other scenarios.
        self.d = [0]
        for n in range(1, self.ngram_order):
            this_order_counts = self.counts[n]
            n1 = 0
            n2 = 0
            for counts_for_hist in this_order_counts.values():
                stat = Counter(counts_for_hist.word_to_count.values())
                n1 += stat[1]
                n2 += stat[2]
            if n1 + 2 * n2 == 0:
                raise ValueError(
                    "no {0}-gram occurs exactly once or twice; cannot compute "
                    "the Kneser-Ney discounting constant".format(n + 1)
                )
            # We are doing this max(0.1, xxx) to avoid zero discounting
            # constant D due to n1=0, which could happen if the number of
            # symbols is small. Otherwise, zero discounting constant can cause
            # division by zero in computing BOW.
            self.d.append(max(0.1, n1 * 1.0) / (n1 + 2 * n2))

    def cal_f(self):
        # f(a_z) is a probability distribution of word sequence a_z.
        # Typically f(a_z) is discounted to be less than the ML estimate so we
        # have some leftover probability for the z words unseen in the
        # context (a_).
        #
        # f(a_z) = (c(a_z) - D0) / c(a_)    ;; for highest order N-grams
        # f(_z)  = (n(*_z) - D1) / n(*_*)   ;; for lower order N-grams
        # where D is self.d[len(history)]; fills word_to_f everywhere.

        # highest order N-grams
        n = self.ngram_order - 1
        this_order_counts = self.counts[n]
        for counts_for_hist in this_order_counts.values():
            for w, c in counts_for_hist.word_to_count.items():
                counts_for_hist.word_to_f[w] = (
                    max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
                )

        # lower order N-grams
        for n in range(0, self.ngram_order - 1):
            this_order_counts = self.counts[n]
            for counts_for_hist in this_order_counts.values():
                n_star_star = 0
                for w in counts_for_hist.word_to_count.keys():
                    n_star_star += len(counts_for_hist.word_to_context[w])

                if n_star_star != 0:
                    for w in counts_for_hist.word_to_count.keys():
                        n_star_z = len(counts_for_hist.word_to_context[w])
                        counts_for_hist.word_to_f[w] = (
                            max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
                        )
                else:
                    # patterns begin with <s>, they do not have "modified
                    # count", so use raw count instead
                    for w in counts_for_hist.word_to_count.keys():
                        n_star_z = counts_for_hist.word_to_count[w]
                        counts_for_hist.word_to_f[w] = (
                            max((n_star_z - self.d[n]), 0)
                            * 1.0
                            / counts_for_hist.total_count
                        )

    def cal_bow(self):
        # Backoff weights are only necessary for ngrams which form a prefix of
        # a longer ngram. Thus, two sorts of ngrams do not have a bow:
        # 1) highest order ngram
        # 2) ngrams ending in </s>
        #
        # bow(a_) = (1 - Sum_Z1 f(a_z)) / (1 - Sum_Z1 f(_z))
        # Note that Z1 is the set of all words with c(a_z) > 0
        # A bow of None means "no back-off weight is written".

        # highest order N-grams
        n = self.ngram_order - 1
        this_order_counts = self.counts[n]
        for counts_for_hist in this_order_counts.values():
            for w in counts_for_hist.word_to_count.keys():
                counts_for_hist.word_to_bow[w] = None

        # lower order N-grams
        for n in range(0, self.ngram_order - 1):
            this_order_counts = self.counts[n]
            for hist, counts_for_hist in this_order_counts.items():
                for w in counts_for_hist.word_to_count.keys():
                    if w == self.eos_symbol:
                        counts_for_hist.word_to_bow[w] = None
                    else:
                        a_ = hist + (w,)

                        # every n-gram not ending in </s> is followed by some
                        # word, so its extension history must exist
                        assert len(a_) < self.ngram_order
                        assert a_ in self.counts[len(a_)].keys()

                        a_counts_for_hist = self.counts[len(a_)][a_]

                        sum_z1_f_a_z = 0
                        for u in a_counts_for_hist.word_to_count.keys():
                            sum_z1_f_a_z += a_counts_for_hist.word_to_f[u]

                        sum_z1_f_z = 0
                        _ = a_[1:]
                        _counts_for_hist = self.counts[len(_)][_]
                        # Should be careful here: what is Z1
                        for u in a_counts_for_hist.word_to_count.keys():
                            sum_z1_f_z += _counts_for_hist.word_to_f[u]

                        if sum_z1_f_z < 1:
                            # assert sum_z1_f_a_z < 1
                            counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (
                                1.0 - sum_z1_f_z
                            )
                        else:
                            counts_for_hist.word_to_bow[w] = None

    def print_raw_counts(self, info_string):
        # these are useful for debug.
        print(info_string)
        res = []
        for this_order_counts in self.counts:
            for hist, counts_for_hist in this_order_counts.items():
                for w in counts_for_hist.word_to_count.keys():
                    ngram = " ".join(hist) + " " + w
                    ngram = ngram.strip(strip_chars)

                    res.append(
                        "{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w])
                    )
        res.sort(reverse=True)
        for r in res:
            print(r)

    def print_modified_counts(self, info_string):
        # these are useful for debug.
        print(info_string)
        res = []
        for this_order_counts in self.counts:
            for hist, counts_for_hist in this_order_counts.items():
                for w in counts_for_hist.word_to_count.keys():
                    ngram = " ".join(hist) + " " + w
                    ngram = ngram.strip(strip_chars)

                    modified_count = len(counts_for_hist.word_to_context[w])
                    raw_count = counts_for_hist.word_to_count[w]

                    if modified_count == 0:
                        res.append("{0}\t{1}".format(ngram, raw_count))
                    else:
                        res.append("{0}\t{1}".format(ngram, modified_count))
        res.sort(reverse=True)
        for r in res:
            print(r)

    def print_f(self, info_string):
        # these are useful for debug.
        print(info_string)
        res = []
        for this_order_counts in self.counts:
            for hist, counts_for_hist in this_order_counts.items():
                for w in counts_for_hist.word_to_count.keys():
                    ngram = " ".join(hist) + " " + w
                    ngram = ngram.strip(strip_chars)

                    f = counts_for_hist.word_to_f[w]
                    if f == 0:  # f(<s>) is always 0
                        f = 1e-99

                    res.append("{0}\t{1}".format(ngram, math.log(f, 10)))
        res.sort(reverse=True)
        for r in res:
            print(r)

    def print_f_and_bow(self, info_string):
        # these are useful for debug.
        print(info_string)
        res = []
        for this_order_counts in self.counts:
            for hist, counts_for_hist in this_order_counts.items():
                for w in counts_for_hist.word_to_count.keys():
                    ngram = " ".join(hist) + " " + w
                    ngram = ngram.strip(strip_chars)

                    f = counts_for_hist.word_to_f[w]
                    if f == 0:  # f(<s>) is always 0
                        f = 1e-99

                    bow = counts_for_hist.word_to_bow[w]
                    if bow is None:
                        res.append("{1}\t{0}".format(ngram, math.log(f, 10)))
                    else:
                        res.append(
                            "{1}\t{0}\t{2}".format(
                                ngram, math.log(f, 10), math.log(bow, 10)
                            )
                        )
        res.sort(reverse=True)
        for r in res:
            print(r)

    def print_as_arpa(self, fout=None):
        # Print the model in ARPA format. Requires cal_f() and cal_bow().
        # fout: text stream to write to. When None (the default), write to
        # standard output encoded with default_encoding (latin-1), matching
        # how the input was decoded.
        #
        # The stdout wrapper is created per call rather than as a default
        # argument (which would be built once at import time and never
        # flushed explicitly). It is flushed and then detached afterwards so
        # that garbage-collecting it does not close sys.stdout.buffer.
        wrap_stdout = fout is None
        if wrap_stdout:
            sys.stdout.flush()  # keep ordering with anything printed earlier
            fout = io.TextIOWrapper(sys.stdout.buffer, encoding=default_encoding)

        try:
            print("\\data\\", file=fout)
            for hist_len in range(self.ngram_order):
                # print the number of n-grams.
                print(
                    "ngram {0}={1}".format(
                        hist_len + 1,
                        sum(
                            [
                                len(counts_for_hist.word_to_f)
                                for counts_for_hist in self.counts[hist_len].values()
                            ]
                        ),
                    ),
                    file=fout,
                )

            print("", file=fout)

            for hist_len in range(self.ngram_order):
                print("\\{0}-grams:".format(hist_len + 1), file=fout)

                this_order_counts = self.counts[hist_len]
                for hist, counts_for_hist in this_order_counts.items():
                    for word in counts_for_hist.word_to_count.keys():
                        ngram = hist + (word,)
                        prob = counts_for_hist.word_to_f[word]
                        bow = counts_for_hist.word_to_bow[word]

                        if prob == 0:  # f(<s>) is always 0
                            prob = 1e-99

                        line = "{0}\t{1}".format(
                            "%.7f" % math.log10(prob), " ".join(ngram)
                        )
                        if bow is not None:
                            line += "\t{0}".format("%.7f" % math.log10(bow))
                        print(line, file=fout)
                print("", file=fout)
            print("\\end\\", file=fout)
        finally:
            if wrap_stdout:
                fout.flush()
                fout.detach()


if __name__ == "__main__":
    # Rebinds the module-level `args` read by NgramCounts.
    args = parser.parse_args()

    ngram_counts = NgramCounts(args.ngram_order)
    if args.text is None:
        ngram_counts.add_raw_counts_from_standard_input()
    else:
        if not os.path.isfile(args.text):
            parser.error("-text: no such file: {0}".format(args.text))
        ngram_counts.add_raw_counts_from_file(args.text)

    ngram_counts.cal_discounting_constants()
    ngram_counts.cal_f()
    ngram_counts.cal_bow()

    if args.lm is None:
        ngram_counts.print_as_arpa()
    else:
        with open(args.lm, "w", encoding=default_encoding) as f:
            ngram_counts.print_as_arpa(fout=f)