Fix style check issues

Signed-off-by: Xinyuan Li <xli257@b17.clsp.jhu.edu>
Xinyuan Li 2024-01-24 11:45:16 -05:00
parent 7047a579b8
commit eec59410f1
7 changed files with 133 additions and 123 deletions

View File

@@ -7,8 +7,9 @@ It looks for manifests in the directory data/manifests.
 The generated fbank features are saved in data/fbank.
 """
+import argparse
 import logging
-import os, argparse
+import os
 from pathlib import Path

 import torch
@@ -82,9 +83,10 @@ def compute_fbank_slu(manifest_dir, fbanks_dir):
            )
            cut_set.to_file(cuts_file)

+
 parser = argparse.ArgumentParser()
-parser.add_argument('manifest_dir')
-parser.add_argument('fbanks_dir')
+parser.add_argument("manifest_dir")
+parser.add_argument("fbanks_dir")

 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

View File

@@ -1,12 +1,22 @@
-import pandas, argparse
+import argparse
+
+import pandas
 from tqdm import tqdm

+
 def generate_lexicon(corpus_dir, lm_dir):
-    data = pandas.read_csv(str(corpus_dir) + '/data/train_data.csv', index_col = 0, header = 0)
+    data = pandas.read_csv(
+        str(corpus_dir) + "/data/train_data.csv", index_col=0, header=0
+    )
     vocab_transcript = set()
     vocab_frames = set()
-    transcripts = data['transcription'].tolist()
-    frames = list(i for i in zip(data['action'].tolist(), data['object'].tolist(), data['location'].tolist()))
+    transcripts = data["transcription"].tolist()
+    frames = list(
+        i
+        for i in zip(
+            data["action"].tolist(), data["object"].tolist(), data["location"].tolist()
+        )
+    )

     for transcript in tqdm(transcripts):
         for word in transcript.split():
@@ -14,34 +24,36 @@ def generate_lexicon(corpus_dir, lm_dir):
     for frame in tqdm(frames):
         for word in frame:
-            vocab_frames.add('_'.join(word.split()))
+            vocab_frames.add("_".join(word.split()))

-    with open(lm_dir + '/words_transcript.txt', 'w') as lexicon_transcript_file:
-        lexicon_transcript_file.write("<UNK> 1" + '\n')
-        lexicon_transcript_file.write("<s> 2" + '\n')
-        lexicon_transcript_file.write("</s> 0" + '\n')
+    with open(lm_dir + "/words_transcript.txt", "w") as lexicon_transcript_file:
+        lexicon_transcript_file.write("<UNK> 1" + "\n")
+        lexicon_transcript_file.write("<s> 2" + "\n")
+        lexicon_transcript_file.write("</s> 0" + "\n")
         id = 3
         for vocab in vocab_transcript:
-            lexicon_transcript_file.write(vocab + ' ' + str(id) + '\n')
+            lexicon_transcript_file.write(vocab + " " + str(id) + "\n")
             id += 1

-    with open(lm_dir + '/words_frames.txt', 'w') as lexicon_frames_file:
-        lexicon_frames_file.write("<UNK> 1" + '\n')
-        lexicon_frames_file.write("<s> 2" + '\n')
-        lexicon_frames_file.write("</s> 0" + '\n')
+    with open(lm_dir + "/words_frames.txt", "w") as lexicon_frames_file:
+        lexicon_frames_file.write("<UNK> 1" + "\n")
+        lexicon_frames_file.write("<s> 2" + "\n")
+        lexicon_frames_file.write("</s> 0" + "\n")
         id = 3
         for vocab in vocab_frames:
-            lexicon_frames_file.write(vocab + ' ' + str(id) + '\n')
+            lexicon_frames_file.write(vocab + " " + str(id) + "\n")
             id += 1

 parser = argparse.ArgumentParser()
-parser.add_argument('corpus_dir')
-parser.add_argument('lm_dir')
+parser.add_argument("corpus_dir")
+parser.add_argument("lm_dir")

+
 def main():
     args = parser.parse_args()
     generate_lexicon(args.corpus_dir, args.lm_dir)

+
 main()
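
Aside: the word files written above are plain "symbol id" pairs, with </s> mapped to 0, <UNK> to 1, <s> to 2, and the remaining vocabulary numbered from 3. A minimal sketch of reading one back into a dict, assuming that format; the helper name read_word_map is hypothetical and not part of the recipe:

# Hypothetical helper: load a "symbol id" file such as words_frames.txt.
def read_word_map(path: str) -> dict:
    word2id = {}
    with open(path) as f:
        for line in f:
            if line.strip():
                symbol, idx = line.split()
                word2id[symbol] = int(idx)
    return word2id


# e.g. read_word_map(lm_dir + "/words_frames.txt")["<UNK>"] == 1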

View File

@@ -19,11 +19,11 @@ consisting of words and tokens (i.e., phones) and does the following:
 5. Generate L_disambig.pt, in k2 format.
 """
+import argparse
 import math
 from collections import defaultdict
 from pathlib import Path
 from typing import Any, Dict, List, Tuple
-import argparse

 import k2
 import torch
@@ -299,8 +299,10 @@ def lexicon_to_fst(
     fsa = k2.Fsa.from_str(arcs, acceptor=False)
     return fsa

+
 parser = argparse.ArgumentParser()
-parser.add_argument('lm_dir')
+parser.add_argument("lm_dir")

+
 def main():
     args = parser.parse_args()
@@ -312,58 +314,58 @@ def main():
     sil_prob = 0.5

     for name, lexicon_filename in zip(names, lexicon_filenames):
         lexicon = read_lexicon(lexicon_filename)
         tokens = get_words(lexicon)
         words = get_words(lexicon)

         new_lexicon = []
         for lexicon_item in lexicon:
             new_lexicon.append((lexicon_item[0], [lexicon_item[0]]))
         lexicon = new_lexicon

         lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)

         for i in range(max_disambig + 1):
             disambig = f"#{i}"
             assert disambig not in tokens
             tokens.append(f"#{i}")

         tokens = ["<eps>"] + tokens
-        words = ['eps'] + words + ["#0", "!SIL"]
+        words = ["eps"] + words + ["#0", "!SIL"]

         token2id = generate_id_map(tokens)
         word2id = generate_id_map(words)

         write_mapping(out_dir / ("tokens_" + name + ".txt"), token2id)
         write_mapping(out_dir / ("words_" + name + ".txt"), word2id)
         write_lexicon(out_dir / ("lexicon_disambig_" + name + ".txt"), lexicon_disambig)

         L = lexicon_to_fst(
             lexicon,
             token2id=word2id,
             word2id=word2id,
             sil_token=sil_token,
             sil_prob=sil_prob,
         )

         L_disambig = lexicon_to_fst(
             lexicon_disambig,
             token2id=word2id,
             word2id=word2id,
             sil_token=sil_token,
             sil_prob=sil_prob,
             need_self_loops=True,
         )
         torch.save(L.as_dict(), out_dir / ("L_" + name + ".pt"))
         torch.save(L_disambig.as_dict(), out_dir / ("L_disambig_" + name + ".pt"))

         if False:
             # Just for debugging, will remove it
             L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
             L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
             L_disambig.labels_sym = L.labels_sym
             L_disambig.aux_labels_sym = L.aux_labels_sym
             L.draw(out_dir / "L.png", title="L")
             L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")


 main()
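
Aside: the saved L_*.pt / L_disambig_*.pt files can be reloaded with k2 for inspection. A minimal sketch, assuming a lexicon named "frames" and a hypothetical output directory; substitute the same out_dir and name used in main():

# Sketch: reload a saved lexicon FST (directory and file name are assumptions).
from pathlib import Path

import k2
import torch

out_dir = Path("data/lm/frames")  # hypothetical; use the out_dir from main()
L = k2.Fsa.from_dict(torch.load(out_dir / "L_frames.pt"))
print(L.num_arcs)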

View File

@@ -20,7 +20,9 @@ import torch
 from transducer.model import Transducer


-def greedy_search(model: Transducer, encoder_out: torch.Tensor, id2word: dict) -> List[str]:
+def greedy_search(
+    model: Transducer, encoder_out: torch.Tensor, id2word: dict
+) -> List[str]:
     """
     Args:
       model:
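
Aside: for readers unfamiliar with transducer decoding, this is roughly what a greedy_search with this signature does. The sketch below is illustrative only; the decoder/joiner call signatures and the one-symbol-per-frame limit are assumptions, not the recipe's actual implementation:

# Illustrative greedy-search loop for a transducer (interfaces are assumed).
import torch


def greedy_search_sketch(model, encoder_out: torch.Tensor, id2word: dict):
    # encoder_out is assumed to have shape (1, T, C): one utterance, T frames.
    blank_id = 0
    hyp = [blank_id]
    decoder_input = torch.tensor([hyp], dtype=torch.int64)
    decoder_out = model.decoder(decoder_input)
    for t in range(encoder_out.size(1)):
        # Join the current acoustic frame with the current prediction state.
        logits = model.joiner(encoder_out[:, t : t + 1, :], decoder_out)
        token = int(logits.argmax(dim=-1).item())
        if token != blank_id:
            # Emit at most one symbol per frame (a simplification) and
            # advance the prediction network.
            hyp.append(token)
            decoder_input = torch.tensor([[token]], dtype=torch.int64)
            decoder_out = model.decoder(decoder_input)
    return [id2word[i] for i in hyp[1:]]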

View File

@@ -22,12 +22,12 @@ from typing import List, Tuple

 import torch
 import torch.nn as nn
-from transducer.slu_datamodule import SluDataModule
 from transducer.beam_search import greedy_search
-from transducer.decoder import Decoder
 from transducer.conformer import Conformer
+from transducer.decoder import Decoder
 from transducer.joiner import Joiner
 from transducer.model import Transducer
+from transducer.slu_datamodule import SluDataModule

 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.env import get_env_info
@@ -45,7 +45,7 @@ def get_id2word(params):
     # 0 is blank
     id = 1
     try:
-        with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
+        with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
             for line in lexicon_file:
                 if len(line.strip()) > 0:
                     id2word[id] = line.split()[0]
@@ -82,11 +82,7 @@ def get_parser():
         default="transducer/exp",
         help="Directory from which to load the checkpoints",
     )
-    parser.add_argument(
-        "--lang-dir",
-        type=str,
-        default="data/lm/frames"
-    )
+    parser.add_argument("--lang-dir", type=str, default="data/lm/frames")

     return parser
@@ -106,9 +102,11 @@ def get_params() -> AttributeDict:
     )

     vocab_size = 1
-    with open(params.lang_dir / 'lexicon_disambig.txt') as lexicon_file:
+    with open(params.lang_dir / "lexicon_disambig.txt") as lexicon_file:
         for line in lexicon_file:
-            if len(line.strip()) > 0:# and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
+            if (
+                len(line.strip()) > 0
+            ):  # and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
                 vocab_size += 1
     params.vocab_size = vocab_size
@@ -116,10 +114,7 @@ def get_params() -> AttributeDict:

 def decode_one_batch(
-    params: AttributeDict,
-    model: nn.Module,
-    batch: dict,
-    id2word: dict
+    params: AttributeDict, model: nn.Module, batch: dict, id2word: dict
 ) -> List[List[int]]:
     """Decode one batch and return the result in a list-of-list.
     Each sub list contains the word IDs for an utterance in the batch.
@@ -195,15 +190,18 @@ def decode_dataset(
     results = []
     for batch_idx, batch in enumerate(dl):
-        texts = [' '.join(a.supervisions[0].custom["frames"]) for a in batch["supervisions"]["cut"]]
-        texts = ['<s> ' + a.replace('change language', 'change_language') + ' </s>' for a in texts]
+        texts = [
+            " ".join(a.supervisions[0].custom["frames"])
+            for a in batch["supervisions"]["cut"]
+        ]
+        texts = [
+            "<s> " + a.replace("change language", "change_language") + " </s>"
+            for a in texts
+        ]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

         hyps = decode_one_batch(
-            params=params,
-            model=model,
-            batch=batch,
-            id2word=id2word
+            params=params, model=model, batch=batch, id2word=id2word
         )

         this_batch = []
@@ -338,7 +336,7 @@ def main():
         model=model,
     )

-    test_set_name=str(args.feature_dir).split('/')[-2]
+    test_set_name = str(args.feature_dir).split("/")[-2]
     save_results(exp_dir=params.exp_dir, test_set_name=test_set_name, results=results)

     logging.info("Done!")
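
Aside: to make the reference-text preprocessing in decode_dataset concrete, the same two steps applied to made-up frame values (the sample data below is hypothetical):

# Example of the reference-text preprocessing used above (sample data only).
frames = ["change language", "none", "none"]  # hypothetical action/object/location
text = " ".join(frames)
text = "<s> " + text.replace("change language", "change_language") + " </s>"
assert text == "<s> change_language none none </s>"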

View File

@@ -282,11 +282,8 @@ class SluDataModule(DataModule):
         )
         return cuts_valid

     @lru_cache()
     def test_cuts(self) -> List[CutSet]:
         logging.info("About to get test cuts")
-        cuts_test = load_manifest_lazy(
-            self.args.feature_dir / "slu_cuts_test.jsonl.gz"
-        )
+        cuts_test = load_manifest_lazy(self.args.feature_dir / "slu_cuts_test.jsonl.gz")
         return cuts_test

View File

@@ -26,14 +26,15 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
-from slu_datamodule import SluDataModule
 from lhotse.utils import fix_random_seed
+from slu_datamodule import SluDataModule
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
+from transducer.conformer import Conformer

 # from torch.utils.tensorboard import SummaryWriter
 from transducer.decoder import Decoder
-from transducer.conformer import Conformer
 from transducer.joiner import Joiner
 from transducer.model import Transducer
@@ -49,20 +50,20 @@ def get_word2id(params):
     # 0 is blank
     id = 1
-    with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
+    with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
         for line in lexicon_file:
             if len(line.strip()) > 0:
                 word2id[line.split()[0]] = id
                 id += 1

     return word2id


 def get_labels(texts: List[str], word2id) -> k2.RaggedTensor:
     """
     Args:
       texts:
         A list of transcripts.
     Returns:
       Return a ragged tensor containing the corresponding word ID.
     """
@@ -133,11 +134,7 @@ def get_parser():
         help="The seed for random generators intended for reproducibility",
     )

-    parser.add_argument(
-        "--lang-dir",
-        type=str,
-        default="data/lm/frames"
-    )
+    parser.add_argument("--lang-dir", type=str, default="data/lm/frames")

     return parser
@@ -215,9 +212,11 @@ def get_params() -> AttributeDict:
     )

     vocab_size = 1
-    with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
+    with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
         for line in lexicon_file:
-            if len(line.strip()) > 0:# and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
+            if (
+                len(line.strip()) > 0
+            ):  # and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
                 vocab_size += 1
     params.vocab_size = vocab_size
@@ -312,11 +311,7 @@ def save_checkpoint(

 def compute_loss(
-    params: AttributeDict,
-    model: nn.Module,
-    batch: dict,
-    is_training: bool,
-    word2ids
+    params: AttributeDict, model: nn.Module, batch: dict, is_training: bool, word2ids
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute RNN-T loss given the model and its inputs.
@@ -342,8 +337,14 @@ def compute_loss(
     feature_lens = batch["supervisions"]["num_frames"].to(device)

-    texts = [' '.join(a.supervisions[0].custom["frames"]) for a in batch["supervisions"]["cut"]]
-    texts = ['<s> ' + a.replace('change language', 'change_language') + ' </s>' for a in texts]
+    texts = [
+        " ".join(a.supervisions[0].custom["frames"])
+        for a in batch["supervisions"]["cut"]
+    ]
+    texts = [
+        "<s> " + a.replace("change language", "change_language") + " </s>"
+        for a in texts
+    ]
     labels = get_labels(texts, word2ids).to(device)

     with torch.set_grad_enabled(is_training):
@@ -378,7 +379,7 @@ def compute_validation_loss(
             model=model,
             batch=batch,
             is_training=False,
-            word2ids=word2ids
+            word2ids=word2ids,
         )

         assert loss.requires_grad is False
@@ -437,11 +438,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         loss, loss_info = compute_loss(
-            params=params,
-            model=model,
-            batch=batch,
-            is_training=True,
-            word2ids=word2ids
+            params=params, model=model, batch=batch, is_training=True, word2ids=word2ids
         )
         # summary stats.
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -471,7 +468,7 @@ def train_one_epoch(
                 model=model,
                 valid_dl=valid_dl,
                 world_size=world_size,
-                word2ids=word2ids
+                word2ids=word2ids,
             )
             model.train()
             logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
@@ -593,7 +590,7 @@ def run(rank, world_size, args):
             valid_dl=valid_dl,
             tb_writer=tb_writer,
             world_size=world_size,
-            word2ids=word2ids
+            word2ids=word2ids,
         )

         save_checkpoint(