Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

Commit c439212733 ("from local"), parent 5dea37de03
@@ -109,10 +109,11 @@ def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
     logging.info("Removing disambiguation symbols on LG")

-    LG.labels[LG.labels >= first_token_disambig_id] = 0
-    # See https://github.com/k2-fsa/k2/issues/874
-    # for why we need to set LG.properties to None
-    LG.__dict__["_properties"] = None
+    # LG.labels[LG.labels >= first_token_disambig_id] = 0
+    # see https://github.com/k2-fsa/k2/pull/1140
+    labels = LG.labels
+    labels[labels >= first_token_disambig_id] = 0
+    LG.labels = labels

     assert isinstance(LG.aux_labels, k2.RaggedTensor)
     LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
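Background on the hunk above (not part of the diff): element-wise assignment into LG.labels mutates the underlying tensor without k2 noticing, so cached FSA properties go stale; k2 PR #1140 makes the labels setter re-validate them, which is why the change reads the labels out, edits the copy, and assigns it back. A minimal sketch of the pattern, assuming a k2 Fsa named fsa and an integer threshold first_token_disambig_id:

    # Read-modify-write so that k2 revalidates cached FSA properties.
    labels = fsa.labels                            # 1-D torch.Tensor of arc labels
    labels[labels >= first_token_disambig_id] = 0  # map disambig symbols to epsilon
    fsa.labels = labels                            # the setter refreshes the properties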
@@ -58,7 +58,6 @@ class Decoder(nn.Module):
         self.embedding = nn.Embedding(
             num_embeddings=vocab_size,
             embedding_dim=embedding_dim,
             padding_idx=blank_id,
         )
         self.blank_id = blank_id
-        self.unk_id = unk_id
@@ -24,7 +24,7 @@ import torch
 from torch import Tensor, nn
 from transformer import Transformer

-from icefall.utils import make_pad_mask, subsequent_chunk_mask
+from icefall.utils import is_jit_tracing, make_pad_mask, subsequent_chunk_mask


 class Conformer(Transformer):
@@ -154,7 +154,8 @@ class Conformer(Transformer):
         # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
         lengths = (((x_lens - 1) >> 1) - 1) >> 1

-        assert x.size(0) == lengths.max().item()
+        if not is_jit_tracing():
+            assert x.size(0) == lengths.max().item()

         src_key_padding_mask = make_pad_mask(lengths)
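Background for this and the following conformer.py hunks (not part of the diff): torch.jit.trace records tensor operations but freezes Python-side values, so a data-dependent assert that calls .item() is baked in as a constant for the example input and emits tracer warnings. Guarding the asserts keeps them active in eager mode only. A minimal sketch of the pattern, assuming is_jit_tracing from icefall.utils behaves like torch.jit.is_tracing():

    import torch

    def downsampled_lengths(x_lens: torch.Tensor) -> torch.Tensor:
        lengths = (((x_lens - 1) >> 1) - 1) >> 1
        if not torch.jit.is_tracing():  # stand-in for icefall.utils.is_jit_tracing
            # Data-dependent check; under tracing, .item() would freeze to
            # whatever value the example input happens to have.
            assert int(lengths.min().item()) >= 0
        return lengths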
@@ -358,6 +359,11 @@ class Conformer(Transformer):

         assert x.size(0) == lengths.max().item()

+        if chunk_size < 0:
+            # use full attention
+            chunk_size = x.size(0)
+            left_context = -1
+
         num_left_chunks = -1
         if left_context >= 0:
             assert left_context % chunk_size == 0
@@ -763,6 +769,14 @@ class RelPositionalEncoding(torch.nn.Module):
     def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
         """Construct a PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
+        if is_jit_tracing():
+            # 10k frames correspond to ~100k ms, i.e., ~100 seconds; this
+            # assumes that the maximum input won't have more than
+            # 10k frames.
+            #
+            # TODO(fangjun): Use torch.jit.script() for this module
+            max_len = 10000
+
         self.d_model = d_model
         self.xscale = math.sqrt(self.d_model)
         self.dropout = torch.nn.Dropout(p=dropout_rate)
@@ -970,22 +984,34 @@ class RelPositionMultiheadAttention(nn.Module):
             the key, while time1 is for the query).
         """
         (batch_size, num_heads, time1, n) = x.shape

         time2 = time1 + left_context
+        if not is_jit_tracing():
+            assert (
+                n == left_context + 2 * time1 - 1
+            ), f"{n} == {left_context} + 2 * {time1} - 1"

-        assert (
-            n == left_context + 2 * time1 - 1
-        ), f"{n} == {left_context} + 2 * {time1} - 1"
+        if is_jit_tracing():
+            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+            cols = torch.arange(time2)
+            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+            indexes = rows + cols
+
-        # Note: TorchScript requires explicit arg for stride()
-        batch_stride = x.stride(0)
-        head_stride = x.stride(1)
-        time1_stride = x.stride(2)
-        n_stride = x.stride(3)
-        return x.as_strided(
-            (batch_size, num_heads, time1, time2),
-            (batch_stride, head_stride, time1_stride - n_stride, n_stride),
-            storage_offset=n_stride * (time1 - 1),
-        )
+            x = x.reshape(-1, n)
+            x = torch.gather(x, dim=1, index=indexes)
+            x = x.reshape(batch_size, num_heads, time1, time2)
+            return x
+        else:
+            # Note: TorchScript requires explicit arg for stride()
+            batch_stride = x.stride(0)
+            head_stride = x.stride(1)
+            time1_stride = x.stride(2)
+            n_stride = x.stride(3)
+            return x.as_strided(
+                (batch_size, num_heads, time1, time2),
+                (batch_stride, head_stride, time1_stride - n_stride, n_stride),
+                storage_offset=n_stride * (time1 - 1),
+            )

     def multi_head_attention_forward(
         self,
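Aside on the rel_shift hunk above (not part of the diff): the traced branch rebuilds the relative-shift result with torch.gather, since as_strided views with data-dependent offsets are awkward under torch.jit.trace. A self-contained check that the two branches agree, assuming left_context = 0 for simplicity:

    import torch

    batch_size, num_heads, time1, left_context = 2, 4, 5, 0
    n = left_context + 2 * time1 - 1
    time2 = time1 + left_context
    x = torch.randn(batch_size, num_heads, time1, n)

    # gather-based shift (the jit-tracing branch)
    rows = torch.arange(start=time1 - 1, end=-1, step=-1)
    cols = torch.arange(time2)
    indexes = rows.repeat(batch_size * num_heads).unsqueeze(-1) + cols
    y1 = torch.gather(x.reshape(-1, n), dim=1, index=indexes)
    y1 = y1.reshape(batch_size, num_heads, time1, time2)

    # as_strided-based shift (the eager branch)
    y2 = x.as_strided(
        (batch_size, num_heads, time1, time2),
        (x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
        storage_offset=x.stride(3) * (time1 - 1),
    )
    assert torch.equal(y1, y2)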
@@ -1056,13 +1082,16 @@ class RelPositionMultiheadAttention(nn.Module):
         """

         tgt_len, bsz, embed_dim = query.size()
-        assert embed_dim == embed_dim_to_check
-        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+        if not is_jit_tracing():
+            assert embed_dim == embed_dim_to_check
+            assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

         head_dim = embed_dim // num_heads
-        assert (
-            head_dim * num_heads == embed_dim
-        ), "embed_dim must be divisible by num_heads"
+        if not is_jit_tracing():
+            assert (
+                head_dim * num_heads == embed_dim
+            ), "embed_dim must be divisible by num_heads"

         scaling = float(head_dim) ** -0.5

         if torch.equal(query, key) and torch.equal(key, value):
@@ -1176,7 +1205,8 @@ class RelPositionMultiheadAttention(nn.Module):
         q = q.transpose(0, 1)  # (batch, time1, head, d_k)

         pos_emb_bsz = pos_emb.size(0)
-        assert pos_emb_bsz in (1, bsz)  # actually it is 1
+        if not is_jit_tracing():
+            assert pos_emb_bsz in (1, bsz)  # actually it is 1
         p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)

         # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
@@ -1207,11 +1237,12 @@ class RelPositionMultiheadAttention(nn.Module):

         attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1)

-        assert list(attn_output_weights.size()) == [
-            bsz * num_heads,
-            tgt_len,
-            src_len,
-        ]
+        if not is_jit_tracing():
+            assert list(attn_output_weights.size()) == [
+                bsz * num_heads,
+                tgt_len,
+                src_len,
+            ]

         if attn_mask is not None:
             if attn_mask.dtype == torch.bool:
@@ -1260,7 +1291,10 @@ class RelPositionMultiheadAttention(nn.Module):
         )

         attn_output = torch.bmm(attn_output_weights, v)
-        assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+
+        if not is_jit_tracing():
+            assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]

         attn_output = (
             attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
         )
BIN egs/tedlium3/ASR/.RESULTS.md.swp (new file; binary file not shown)
@@ -23,7 +23,7 @@ stop_stage=100
 # - music
 # - noise
 # - speech
-dl_dir=/home/work/workspace/tedlium3
+dl_dir=$PWD/download

 . shared/parse_options.sh || exit 1
@@ -379,18 +379,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -400,9 +396,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -354,18 +354,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -375,9 +371,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
icefall/shared/ngram_entropy_pruning.py (new executable file, 630 lines)
@@ -0,0 +1,630 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2021 Johns Hopkins University (Author: Ruizhe Huang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
    ./ngram_entropy_pruning.py \
        -threshold 1e-8 \
        -lm download/lm/4gram.arpa \
        -write-lm download/lm/4gram_pruned_1e8.arpa

This file is from Kaldi `egs/wsj/s5/utils/lang/ngram_entropy_pruning.py`.
This is an implementation of ``Entropy-based Pruning of Backoff Language Models''
in the same way as SRILM.
"""


import argparse
import gzip
import logging
import math
import re
from collections import OrderedDict, defaultdict
from enum import Enum, unique
from io import StringIO

parser = argparse.ArgumentParser(
    description="""
    Prune an n-gram language model based on the relative entropy
    between the original and the pruned model, based on Andreas Stolcke's paper.
    An n-gram entry is removed if the removal causes the (training-set)
    perplexity of the model to increase by less than the given relative threshold.

    The command takes an arpa file and a pruning threshold as input,
    and outputs a pruned arpa file.
    """
)
parser.add_argument("-threshold", type=float, default=1e-6, help="Pruning threshold")
parser.add_argument("-lm", type=str, default=None, help="Path to the input arpa file")
parser.add_argument(
    "-write-lm", type=str, default=None, help="Path to output arpa file after pruning"
)
parser.add_argument(
    "-minorder",
    type=int,
    default=1,
    help="The minorder parameter limits pruning to ngrams of that length and above.",
)
parser.add_argument(
    "-encoding", type=str, default="utf-8", help="Encoding of the arpa file"
)
parser.add_argument(
    "-verbose",
    type=int,
    default=2,
    choices=[0, 1, 2, 3, 4, 5],
    help="Verbose level, where 0 is most noisy; 5 is most silent",
)
args = parser.parse_args()

default_encoding = args.encoding
logging.basicConfig(
    format="%(asctime)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s",
    level=args.verbose * 10,
)
class Context(dict):
    """
    This class stores data for a context h.
    It behaves like a Python dict object, except that it has several
    additional attributes.
    """

    def __init__(self):
        super().__init__()
        self.log_bo = None


class Arpa:
    """
    This is a class that implements the data structure of an ARPA LM.
    It (as well as some other classes) is modified based on the library
    by Stefan Fischer:
    https://github.com/sfischer13/python-arpa
    """

    UNK = "<unk>"
    SOS = "<s>"
    EOS = "</s>"
    FLOAT_NDIGITS = 7
    base = 10

    @staticmethod
    def _check_input(my_input):
        if not my_input:
            raise ValueError
        elif isinstance(my_input, tuple):
            return my_input
        elif isinstance(my_input, list):
            return tuple(my_input)
        elif isinstance(my_input, str):
            return tuple(my_input.strip().split(" "))
        else:
            raise ValueError

    @staticmethod
    def _check_word(input_word):
        if not isinstance(input_word, str):
            raise ValueError
        if " " in input_word:
            raise ValueError

    def _replace_unks(self, words):
        return tuple((w if w in self else self._unk) for w in words)

    def __init__(self, path=None, encoding=None, unk=None):
        self._counts = OrderedDict()
        self._ngrams = (
            OrderedDict()
        )  # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w)
        self._vocabulary = set()
        if unk is None:
            self._unk = self.UNK
        else:
            # keep a caller-provided unk symbol (otherwise self._unk
            # would never be set on this path)
            self._unk = unk

        if path is not None:
            self.loadf(path, encoding)

    def __contains__(self, ngram):
        h = ngram[:-1]  # h is a tuple
        w = ngram[-1]  # w is a string/word
        return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]

    def contains_word(self, word):
        self._check_word(word)
        return word in self._vocabulary

    def add_count(self, order, count):
        self._counts[order] = count
        self._ngrams[order - 1] = defaultdict(Context)

    def update_counts(self):
        for order in range(1, self.order() + 1):
            count = sum([len(wlist) for _, wlist in self._ngrams[order - 1].items()])
            if count > 0:
                self._counts[order] = count

    def add_entry(self, ngram, p, bo=None, order=None):
        # Note: ngram is a tuple of strings, e.g. ("w1", "w2", "w3")
        h = ngram[:-1]  # h is a tuple
        w = ngram[-1]  # w is a string/word

        # Note that p and bo here are in fact in the log domain (self.base = 10)
        h_context = self._ngrams[len(h)][h]
        h_context[w] = p
        if bo is not None:
            self._ngrams[len(ngram)][ngram].log_bo = bo

        for word in ngram:
            self._vocabulary.add(word)

    def counts(self):
        return sorted(self._counts.items())

    def order(self):
        return max(self._counts.keys(), default=None)

    def vocabulary(self, sort=True):
        if sort:
            return sorted(self._vocabulary)
        else:
            return self._vocabulary

    def _entries(self, order):
        return (
            self._entry(h, w)
            for h, wlist in self._ngrams[order - 1].items()
            for w in wlist
        )

    def _entry(self, h, w):
        # return the entry for the ngram (h, w)
        ngram = h + (w,)
        log_p = self._ngrams[len(h)][h][w]
        log_bo = self._log_bo(ngram)
        if log_bo is not None:
            return (
                round(log_p, self.FLOAT_NDIGITS),
                ngram,
                round(log_bo, self.FLOAT_NDIGITS),
            )
        else:
            return round(log_p, self.FLOAT_NDIGITS), ngram

    def _log_bo(self, ngram):
        if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]:
            return self._ngrams[len(ngram)][ngram].log_bo
        else:
            return None

    def _log_p(self, ngram):
        h = ngram[:-1]  # h is a tuple
        w = ngram[-1]  # w is a string/word
        if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]:
            return self._ngrams[len(h)][h][w]
        else:
            return None

    def log_p_raw(self, ngram):
        log_p = self._log_p(ngram)
        if log_p is not None:
            return log_p
        else:
            if len(ngram) == 1:
                raise KeyError
            else:
                log_bo = self._log_bo(ngram[:-1])
                if log_bo is None:
                    log_bo = 0
                return log_bo + self.log_p_raw(ngram[1:])
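    # A worked illustration of the backoff recursion in log_p_raw above
    # (editorial note, not in the committed file): for a trigram (u, v, w)
    # absent from the model,
    #     log P(w | u, v) = log BOW(u, v) + log P(w | v),
    # and if (v, w) is absent as well,
    #     log P(w | v) = log BOW(v) + log P(w),
    # where a missing backoff weight is treated as log 1 = 0.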
    def log_joint_prob(self, sequence):
        # Compute the joint prob of the sequence based on the chain rule
        # Note that sequence should be a tuple of strings
        #
        # Reference:
        # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527

        log_joint_p = 0
        seq = sequence
        while len(seq) > 0:
            log_joint_p += self.log_p_raw(seq)
            seq = seq[:-1]

            # If we're computing the marginal probability of the unigram
            # <s> context we have to look up </s> instead since the former
            # has prob = 0.
            if len(seq) == 1 and seq[0] == self.SOS:
                seq = (self.EOS,)

        return log_joint_p

    def set_new_context(self, h):
        old_context = self._ngrams[len(h)][h]
        self._ngrams[len(h)][h] = Context()
        return old_context

    def log_p(self, ngram):
        words = self._check_input(ngram)
        if self._unk:
            words = self._replace_unks(words)
        return self.log_p_raw(words)

    def log_s(self, sentence, sos=SOS, eos=EOS):
        words = self._check_input(sentence)
        if self._unk:
            words = self._replace_unks(words)
        if sos:
            words = (sos,) + words
        if eos:
            words = words + (eos,)
        result = sum(self.log_p_raw(words[:i]) for i in range(1, len(words) + 1))
        if sos:
            result = result - self.log_p_raw(words[:1])
        return result

    def p(self, ngram):
        return self.base ** self.log_p(ngram)

    def s(self, sentence):
        return self.base ** self.log_s(sentence)

    def write(self, fp):
        fp.write("\n\\data\\\n")
        for order, count in self.counts():
            fp.write("ngram {}={}\n".format(order, count))
        fp.write("\n")
        for order, _ in self.counts():
            fp.write("\\{}-grams:\n".format(order))
            for e in self._entries(order):
                prob = e[0]
                ngram = " ".join(e[1])
                if len(e) == 2:
                    fp.write("{}\t{}\n".format(prob, ngram))
                elif len(e) == 3:
                    backoff = e[2]
                    fp.write("{}\t{}\t{}\n".format(prob, ngram, backoff))
                else:
                    raise ValueError
            fp.write("\n")
        fp.write("\\end\\\n")
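# For orientation (editorial note, not in the committed file): write() above
# emits the standard ARPA text layout, e.g. for a bigram model:
#
#     \data\
#     ngram 1=<count>
#     ngram 2=<count>
#
#     \1-grams:
#     <log10 prob>\t<word>\t<log10 backoff>
#     ...
#
#     \2-grams:
#     <log10 prob>\t<w1 w2>
#     ...
#
#     \end\
#
# Entries carry the third field only when the n-gram has a backoff weight.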
class ArpaParser:
    """
    This is a class that implements a parser of an ARPA file.
    """

    @unique
    class State(Enum):
        DATA = 1
        COUNT = 2
        HEADER = 3
        ENTRY = 4

    re_count = re.compile(r"^ngram (\d+)=(\d+)$")
    re_header = re.compile(r"^\\(\d+)-grams:$")
    re_entry = re.compile(
        "^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)"
        "\t"
        "(\\S+( \\S+)*)"
        "(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$"
    )

    def _parse(self, fp):
        self._result = []
        self._state = self.State.DATA
        self._tmp_model = None
        self._tmp_order = None
        for line in fp:
            line = line.strip()
            if self._state == self.State.DATA:
                self._data(line)
            elif self._state == self.State.COUNT:
                self._count(line)
            elif self._state == self.State.HEADER:
                self._header(line)
            elif self._state == self.State.ENTRY:
                self._entry(line)
        if self._state != self.State.DATA:
            raise Exception(line)
        return self._result

    def _data(self, line):
        if line == "\\data\\":
            self._state = self.State.COUNT
            self._tmp_model = Arpa()
        else:
            pass  # skip comment line

    def _count(self, line):
        match = self.re_count.match(line)
        if match:
            order = match.group(1)
            count = match.group(2)
            self._tmp_model.add_count(int(order), int(count))
        elif not line:
            self._state = self.State.HEADER  # there are no counts
        else:
            raise Exception(line)

    def _header(self, line):
        match = self.re_header.match(line)
        if match:
            self._state = self.State.ENTRY
            self._tmp_order = int(match.group(1))
        elif line == "\\end\\":
            self._result.append(self._tmp_model)
            self._state = self.State.DATA
            self._tmp_model = None
            self._tmp_order = None
        elif not line:
            pass  # skip empty line
        else:
            raise Exception(line)

    def _entry(self, line):
        match = self.re_entry.match(line)
        if match:
            p = self._float_or_int(match.group(1))
            ngram = tuple(match.group(4).split(" "))
            bo_match = match.group(7)
            bo = self._float_or_int(bo_match) if bo_match else None
            self._tmp_model.add_entry(ngram, p, bo, self._tmp_order)
        elif not line:
            self._state = self.State.HEADER  # last entry
        else:
            raise Exception(line)

    @staticmethod
    def _float_or_int(s):
        f = float(s)
        i = int(f)
        if str(i) == s:  # don't drop trailing ".0"
            return i
        else:
            return f

    def load(self, fp):
        """Deserialize fp (a file-like object) to a Python object."""
        return self._parse(fp)

    def loadf(self, path, encoding=None):
        """Deserialize path (.arpa, .gz) to a Python object."""
        path = str(path)
        if path.endswith(".gz"):
            with gzip.open(path, mode="rt", encoding=encoding) as f:
                return self.load(f)
        else:
            with open(path, mode="rt", encoding=encoding) as f:
                return self.load(f)

    def loads(self, s):
        """Deserialize s (a str) to a Python object."""
        with StringIO(s) as f:
            return self.load(f)

    def dump(self, obj, fp):
        """Serialize obj to fp (a file-like object) in ARPA format."""
        obj.write(fp)

    def dumpf(self, obj, path, encoding=None):
        """Serialize obj to path in ARPA format (.arpa, .gz)."""
        path = str(path)
        if path.endswith(".gz"):
            with gzip.open(path, mode="wt", encoding=encoding) as f:
                return self.dump(obj, f)
        else:
            with open(path, mode="wt", encoding=encoding) as f:
                self.dump(obj, f)

    def dumps(self, obj):
        """Serialize obj to an ARPA formatted str."""
        with StringIO() as f:
            self.dump(obj, f)
            return f.getvalue()
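# Editorial sketch (not in the committed file): a tiny unigram model
# exercises the ArpaParser round trip; this helper is never called by
# the script itself.
def _demo_round_trip():
    arpa_text = (
        "\\data\\\n"
        "ngram 1=2\n"
        "\n"
        "\\1-grams:\n"
        "-0.30103\t</s>\n"
        "-0.30103\tthe\n"
        "\n"
        "\\end\\\n"
    )
    model = ArpaParser().loads(arpa_text)[0]
    # log10 P(the) as stored in the file
    assert abs(model.log_p_raw(("the",)) + 0.30103) < 1e-6
    return ArpaParser().dumps(model)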
def add_log_p(prev_log_sum, log_p, base):
    return math.log(base**log_p + base**prev_log_sum, base)


def compute_numerator_denominator(lm, h):
    log_sum_seen_h = -math.inf
    log_sum_seen_h_lower = -math.inf
    base = lm.base
    for w, log_p in lm._ngrams[len(h)][h].items():
        log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base)

        ngram = h + (w,)
        log_p_lower = lm.log_p_raw(ngram[1:])
        log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower, base)

    numerator = 1.0 - base**log_sum_seen_h
    denominator = 1.0 - base**log_sum_seen_h_lower
    return numerator, denominator
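# Editorial note on compute_numerator_denominator (not in the committed file):
# for a context h, these are the two sides of the Katz backoff weight,
#     BOW(h) = (1 - sum_{seen w} P(w | h)) / (1 - sum_{seen w} P(w | h')),
# where h' drops the first word of h. prune() adjusts each sum by one term
# to get the backoff weight after removing a single n-gram (h, w).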
def prune(lm, threshold, minorder):
    # Reference:
    # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330

    for i in range(
        lm.order(), max(minorder - 1, 1), -1
    ):  # i is the order of the ngram (h, w)
        logging.info("processing %d-grams ..." % i)
        count_pruned_ngrams = 0

        h_dict = lm._ngrams[i - 1]
        for h in list(h_dict.keys()):
            # old backoff weight, BOW(h)
            log_bow = lm._log_bo(h)
            if log_bow is None:
                log_bow = 0

            # Compute numerator and denominator of the backoff weight,
            # so that we can quickly compute the BOW adjustment due to
            # leaving out one prob.
            numerator, denominator = compute_numerator_denominator(lm, h)

            # assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5

            # Compute the marginal probability of the context, P(h)
            h_log_p = lm.log_joint_prob(h)

            all_pruned = True
            pruned_w_set = set()

            for w, log_p in h_dict[h].items():
                ngram = h + (w,)

                # lower-order estimate for ngramProb, P(w|h')
                backoff_prob = lm.log_p_raw(ngram[1:])

                # Compute BOW after removing ngram, BOW'(h)
                new_log_bow = math.log(
                    numerator + lm.base**log_p, lm.base
                ) - math.log(denominator + lm.base**backoff_prob, lm.base)

                # Compute change in entropy due to removal of ngram
                delta_prob = backoff_prob + new_log_bow - log_p
                delta_entropy = -(lm.base**h_log_p) * (
                    (lm.base**log_p) * delta_prob + numerator * (new_log_bow - log_bow)
                )

                # compute relative change in model (training set) perplexity
                perp_change = lm.base**delta_entropy - 1.0

                pruned = threshold > 0 and perp_change < threshold

                # Make sure we don't prune ngrams whose backoff nodes are needed
                if (
                    pruned
                    and len(ngram) in lm._ngrams
                    and len(lm._ngrams[len(ngram)][ngram]) > 0
                ):
                    pruned = False

                logging.debug(
                    "CONTEXT "
                    + str(h)
                    + " WORD "
                    + w
                    + " CONTEXTPROB %f " % h_log_p
                    + " OLDPROB %f " % log_p
                    + " NEWPROB %f " % (backoff_prob + new_log_bow)
                    + " DELTA-H %f " % delta_entropy
                    + " DELTA-LOGP %f " % delta_prob
                    + " PPL-CHANGE %f " % perp_change
                    + " PRUNED "
                    + str(pruned)
                )

                if pruned:
                    pruned_w_set.add(w)
                    count_pruned_ngrams += 1
                else:
                    all_pruned = False

            # If we removed all ngrams for this context we can
            # remove the context itself, but only if the present
            # context is not a prefix to a longer one.
            if all_pruned and len(pruned_w_set) == len(h_dict[h]):
                # this context h is no longer needed, as its ngram prob is
                # stored at its own context
                del h_dict[h]
            elif len(pruned_w_set) > 0:
                # The pruning for this context h is actually done here
                old_context = lm.set_new_context(h)

                for w, p_w in old_context.items():
                    if w not in pruned_w_set:
                        # the entry (h, w) is stored at the context h
                        lm.add_entry(h + (w,), p_w)

            # We need to recompute the back-off weight, but
            # this can only be done after completing the pruning
            # of the lower-order ngrams.
            # Reference:
            # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124

        logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i))

    # recompute backoff weights
    for i in range(
        max(minorder - 1, 1) + 1, lm.order() + 1
    ):  # be careful of this order: from low- to high-order
        for h in lm._ngrams[i - 1]:
            numerator, denominator = compute_numerator_denominator(lm, h)
            new_log_bow = math.log(numerator, lm.base) - math.log(denominator, lm.base)
            lm._ngrams[len(h)][h].log_bo = new_log_bow

    # update counts
    lm.update_counts()

    return
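# Editorial note (not in the committed file): the quantity delta_entropy in
# prune() is the criterion from Stolcke, "Entropy-based Pruning of Backoff
# Language Models": dropping the single n-gram (h, w) changes the model
# entropy by
#     dH = -P(h) * [ P(w|h) * (log P'(w|h) - log P(w|h))
#                    + (1 - sum_{seen w'} P(w'|h)) * (log BOW'(h) - log BOW(h)) ]
# where P' and BOW' are the pruned model's probability and backoff weight,
# and all logs here are base 10. The n-gram is pruned when the relative
# perplexity change, base**dH - 1, stays below the threshold.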
def check_h_is_valid(lm, h):
    sum_under_h = sum(
        [lm.base ** lm.log_p_raw(h + (w,)) for w in lm.vocabulary(sort=False)]
    )
    if abs(sum_under_h - 1.0) > 1e-6:
        logging.info("warning: %s %f" % (str(h), sum_under_h))
        return False
    else:
        return True


def validate_lm(lm):
    # sanity check if the conditional probability sums to one under each context h
    for i in range(lm.order(), 0, -1):  # i is the order of the ngram (h, w)
        logging.info("validating %d-grams ..." % i)
        h_dict = lm._ngrams[i - 1]
        for h in h_dict.keys():
            check_h_is_valid(lm, h)


def compare_two_apras(path1, path2):
    pass


if __name__ == "__main__":
    # load an arpa file
    logging.info("Loading the arpa file from %s" % args.lm)
    parser = ArpaParser()
    models = parser.loadf(args.lm, encoding=default_encoding)
    lm = models[0]  # ARPA files may contain several models.
    logging.info("Stats before pruning:")
    for i, cnt in lm.counts():
        logging.info("ngram %d=%d" % (i, cnt))

    # prune it, the language model will be modified in-place
    logging.info("Start pruning the model with threshold=%.3E..." % args.threshold)
    prune(lm, args.threshold, args.minorder)

    # validate_lm(lm)

    # write the arpa language model to a file
    logging.info("Stats after pruning:")
    for i, cnt in lm.counts():
        logging.info("ngram %d=%d" % (i, cnt))
    logging.info("Saving the pruned arpa file to %s" % args.write_lm)
    parser.dumpf(lm, args.write_lm, encoding=default_encoding)
    logging.info("Done.")