fixed a case where BOW can have problem to compute (ZeroDivisionError)

This commit is contained in:
huangruizhe 2022-01-02 15:29:50 -08:00 committed by GitHub
parent 0a67015d63
commit 82c8fac6ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -165,7 +165,7 @@ class NgramCounts:
n1 += stat[1]
n2 += stat[2]
assert n1 + 2 * n2 > 0
self.d.append(max(0.001, n1 * 1.0) / (n1 + 2 * n2)) # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0,
self.d.append(max(0.1, n1 * 1.0) / (n1 + 2 * n2)) # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0,
# which could happen if the number of symbols is small.
# Otherwise, zero discounting constant can cause division by zero in computing BOW.
@ -243,7 +243,10 @@ class NgramCounts:
for u in a_counts_for_hist.word_to_count.keys(): # Should be careful here: what is Z1
sum_z1_f_z += _counts_for_hist.word_to_f[u]
if 1.0 - sum_z1_f_z == 0:
counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z)
else:
counts_for_hist.word_to_bow[w] = None
def print_raw_counts(self, info_string):
# these are useful for debug.