Fix style check issues

Signed-off-by: Xinyuan Li <xli257@b17.clsp.jhu.edu>
Xinyuan Li 2024-01-24 11:45:16 -05:00
parent 7047a579b8
commit eec59410f1
7 changed files with 133 additions and 123 deletions

View File

@@ -7,8 +7,9 @@ It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os, argparse
import os
from pathlib import Path
import torch
@@ -82,9 +83,10 @@ def compute_fbank_slu(manifest_dir, fbanks_dir):
)
cut_set.to_file(cuts_file)
parser = argparse.ArgumentParser()
parser.add_argument('manifest_dir')
parser.add_argument('fbanks_dir')
parser.add_argument("manifest_dir")
parser.add_argument("fbanks_dir")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
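Taken together, the fixes in this file split the combined os/argparse import, switch to double-quoted strings, and keep the parser at module level. A minimal sketch of the resulting top level, assuming the usual logging setup and that the parsed arguments are simply forwarded to compute_fbank_slu (neither is shown in the hunks above):

import argparse
import logging
import os
from pathlib import Path

import torch

parser = argparse.ArgumentParser()
parser.add_argument("manifest_dir")
parser.add_argument("fbanks_dir")

if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    # Assumed wiring, not shown in the diff: configure logging and run the job.
    logging.basicConfig(format=formatter, level=logging.INFO)
    args = parser.parse_args()
    compute_fbank_slu(args.manifest_dir, args.fbanks_dir)  # defined earlier in the file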

View File

@@ -1,12 +1,22 @@
import pandas, argparse
import argparse
import pandas
from tqdm import tqdm
def generate_lexicon(corpus_dir, lm_dir):
data = pandas.read_csv(str(corpus_dir) + '/data/train_data.csv', index_col = 0, header = 0)
data = pandas.read_csv(
str(corpus_dir) + "/data/train_data.csv", index_col=0, header=0
)
vocab_transcript = set()
vocab_frames = set()
transcripts = data['transcription'].tolist()
frames = list(i for i in zip(data['action'].tolist(), data['object'].tolist(), data['location'].tolist()))
transcripts = data["transcription"].tolist()
frames = list(
i
for i in zip(
data["action"].tolist(), data["object"].tolist(), data["location"].tolist()
)
)
for transcript in tqdm(transcripts):
for word in transcript.split():
@@ -14,34 +24,36 @@ def generate_lexicon(corpus_dir, lm_dir):
for frame in tqdm(frames):
for word in frame:
vocab_frames.add('_'.join(word.split()))
with open(lm_dir + '/words_transcript.txt', 'w') as lexicon_transcript_file:
lexicon_transcript_file.write("<UNK> 1" + '\n')
lexicon_transcript_file.write("<s> 2" + '\n')
lexicon_transcript_file.write("</s> 0" + '\n')
vocab_frames.add("_".join(word.split()))
with open(lm_dir + "/words_transcript.txt", "w") as lexicon_transcript_file:
lexicon_transcript_file.write("<UNK> 1" + "\n")
lexicon_transcript_file.write("<s> 2" + "\n")
lexicon_transcript_file.write("</s> 0" + "\n")
id = 3
for vocab in vocab_transcript:
lexicon_transcript_file.write(vocab + ' ' + str(id) + '\n')
lexicon_transcript_file.write(vocab + " " + str(id) + "\n")
id += 1
with open(lm_dir + '/words_frames.txt', 'w') as lexicon_frames_file:
lexicon_frames_file.write("<UNK> 1" + '\n')
lexicon_frames_file.write("<s> 2" + '\n')
lexicon_frames_file.write("</s> 0" + '\n')
with open(lm_dir + "/words_frames.txt", "w") as lexicon_frames_file:
lexicon_frames_file.write("<UNK> 1" + "\n")
lexicon_frames_file.write("<s> 2" + "\n")
lexicon_frames_file.write("</s> 0" + "\n")
id = 3
for vocab in vocab_frames:
lexicon_frames_file.write(vocab + ' ' + str(id) + '\n')
lexicon_frames_file.write(vocab + " " + str(id) + "\n")
id += 1
parser = argparse.ArgumentParser()
parser.add_argument('corpus_dir')
parser.add_argument('lm_dir')
parser.add_argument("corpus_dir")
parser.add_argument("lm_dir")
def main():
args = parser.parse_args()
generate_lexicon(args.corpus_dir, args.lm_dir)
main()
main()
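For context, both word lists written by this script begin with the fixed entries above and then number the remaining vocabulary from 3. A hypothetical usage sketch, assuming it runs in the same module as generate_lexicon (the corpus path is a placeholder, and the entry after the fixed symbols is only illustrative, since set iteration order is arbitrary):

generate_lexicon("/path/to/corpus", "data/lm")
with open("data/lm/words_frames.txt") as f:
    print(f.read().splitlines()[:4])
# ['<UNK> 1', '<s> 2', '</s> 0', 'change_language 3']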

View File

@@ -19,11 +19,11 @@ consisting of words and tokens (i.e., phones) and does the following:
5. Generate L_disambig.pt, in k2 format.
"""
import argparse
import math
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
import argparse
import k2
import torch
@@ -299,8 +299,10 @@ def lexicon_to_fst(
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
parser = argparse.ArgumentParser()
parser.add_argument('lm_dir')
parser.add_argument("lm_dir")
def main():
args = parser.parse_args()
@@ -312,58 +314,58 @@ def main():
sil_prob = 0.5
for name, lexicon_filename in zip(names, lexicon_filenames):
lexicon = read_lexicon(lexicon_filename)
tokens = get_words(lexicon)
words = get_words(lexicon)
new_lexicon = []
for lexicon_item in lexicon:
new_lexicon.append((lexicon_item[0], [lexicon_item[0]]))
lexicon = new_lexicon
lexicon = read_lexicon(lexicon_filename)
tokens = get_words(lexicon)
words = get_words(lexicon)
new_lexicon = []
for lexicon_item in lexicon:
new_lexicon.append((lexicon_item[0], [lexicon_item[0]]))
lexicon = new_lexicon
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in tokens
tokens.append(f"#{i}")
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in tokens
tokens.append(f"#{i}")
tokens = ["<eps>"] + tokens
words = ['eps'] + words + ["#0", "!SIL"]
tokens = ["<eps>"] + tokens
words = ["eps"] + words + ["#0", "!SIL"]
token2id = generate_id_map(tokens)
word2id = generate_id_map(words)
token2id = generate_id_map(tokens)
word2id = generate_id_map(words)
write_mapping(out_dir / ("tokens_" + name + ".txt"), token2id)
write_mapping(out_dir / ("words_" + name + ".txt"), word2id)
write_lexicon(out_dir / ("lexicon_disambig_" + name + ".txt"), lexicon_disambig)
write_mapping(out_dir / ("tokens_" + name + ".txt"), token2id)
write_mapping(out_dir / ("words_" + name + ".txt"), word2id)
write_lexicon(out_dir / ("lexicon_disambig_" + name + ".txt"), lexicon_disambig)
L = lexicon_to_fst(
lexicon,
token2id=word2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
)
L = lexicon_to_fst(
lexicon,
token2id=word2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
)
L_disambig = lexicon_to_fst(
lexicon_disambig,
token2id=word2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
need_self_loops=True,
)
torch.save(L.as_dict(), out_dir / ("L_" + name + ".pt"))
torch.save(L_disambig.as_dict(), out_dir / ("L_disambig_" + name + ".pt"))
L_disambig = lexicon_to_fst(
lexicon_disambig,
token2id=word2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
need_self_loops=True,
)
torch.save(L.as_dict(), out_dir / ("L_" + name + ".pt"))
torch.save(L_disambig.as_dict(), out_dir / ("L_disambig_" + name + ".pt"))
if False:
# Just for debugging, will remove it
L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
L_disambig.labels_sym = L.labels_sym
L_disambig.aux_labels_sym = L.aux_labels_sym
L.draw(out_dir / "L.png", title="L")
L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
if False:
# Just for debugging, will remove it
L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
L_disambig.labels_sym = L.labels_sym
L_disambig.aux_labels_sym = L.aux_labels_sym
L.draw(out_dir / "L.png", title="L")
L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
main()
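As a side note on the token and word tables written above, generate_id_map is assumed to assign consecutive integer IDs in list order (so "<eps>" gets 0 by virtue of being prepended), and write_mapping to emit one "symbol id" pair per line. A rough self-contained sketch of that assumption:

def generate_id_map_sketch(symbols):
    # Assumed behaviour: consecutive integer IDs in list order, starting at 0.
    return {sym: i for i, sym in enumerate(symbols)}

tokens = ["<eps>", "activate", "deactivate", "#0"]
token2id = generate_id_map_sketch(tokens)
# token2id == {"<eps>": 0, "activate": 1, "deactivate": 2, "#0": 3}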

View File

@@ -20,7 +20,9 @@ import torch
from transducer.model import Transducer
def greedy_search(model: Transducer, encoder_out: torch.Tensor, id2word: dict) -> List[str]:
def greedy_search(
model: Transducer, encoder_out: torch.Tensor, id2word: dict
) -> List[str]:
"""
Args:
model:

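The reformatted signature above decodes one utterance's encoder output into words via id2word. A stripped-down sketch of transducer greedy decoding under that signature; the stateless decoder/joiner interfaces and the per-frame symbol cap are assumptions, not this recipe's exact API:

def greedy_search_sketch(
    decoder, joiner, encoder_out, id2word, blank_id=0, max_sym_per_frame=3
):
    # encoder_out: (T, C) acoustic frames for a single utterance.
    hyp = []
    prev_label = blank_id
    for t in range(len(encoder_out)):
        for _ in range(max_sym_per_frame):
            # Assumed interfaces: decoder(prev_label) -> decoder frame,
            # joiner(enc_frame, dec_frame) -> per-label scores.
            scores = joiner(encoder_out[t], decoder(prev_label))
            label = int(scores.argmax())
            if label == blank_id:
                break  # blank: advance to the next acoustic frame
            hyp.append(label)
            prev_label = label
    return [id2word[i] for i in hyp]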
View File

@@ -22,12 +22,12 @@ from typing import List, Tuple
import torch
import torch.nn as nn
from transducer.slu_datamodule import SluDataModule
from transducer.beam_search import greedy_search
from transducer.decoder import Decoder
from transducer.conformer import Conformer
from transducer.decoder import Decoder
from transducer.joiner import Joiner
from transducer.model import Transducer
from transducer.slu_datamodule import SluDataModule
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
@@ -45,7 +45,7 @@ def get_id2word(params):
# 0 is blank
id = 1
try:
with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
for line in lexicon_file:
if len(line.strip()) > 0:
id2word[id] = line.split()[0]
@@ -82,11 +82,7 @@ def get_parser():
default="transducer/exp",
help="Directory from which to load the checkpoints",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lm/frames"
)
parser.add_argument("--lang-dir", type=str, default="data/lm/frames")
return parser
@@ -106,9 +102,11 @@ def get_params() -> AttributeDict:
)
vocab_size = 1
with open(params.lang_dir / 'lexicon_disambig.txt') as lexicon_file:
with open(params.lang_dir / "lexicon_disambig.txt") as lexicon_file:
for line in lexicon_file:
if len(line.strip()) > 0:# and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
if (
len(line.strip()) > 0
): # and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
vocab_size += 1
params.vocab_size = vocab_size
@@ -116,10 +114,7 @@ def get_params() -> AttributeDict:
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
batch: dict,
id2word: dict
params: AttributeDict, model: nn.Module, batch: dict, id2word: dict
) -> List[List[int]]:
"""Decode one batch and return the result in a list-of-list.
Each sub list contains the word IDs for an utterance in the batch.
@@ -195,15 +190,18 @@ def decode_dataset(
results = []
for batch_idx, batch in enumerate(dl):
texts = [' '.join(a.supervisions[0].custom["frames"]) for a in batch["supervisions"]["cut"]]
texts = ['<s> ' + a.replace('change language', 'change_language') + ' </s>' for a in texts]
texts = [
" ".join(a.supervisions[0].custom["frames"])
for a in batch["supervisions"]["cut"]
]
texts = [
"<s> " + a.replace("change language", "change_language") + " </s>"
for a in texts
]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps = decode_one_batch(
params=params,
model=model,
batch=batch,
id2word=id2word
params=params, model=model, batch=batch, id2word=id2word
)
this_batch = []
@@ -338,7 +336,7 @@ def main():
model=model,
)
test_set_name=str(args.feature_dir).split('/')[-2]
test_set_name = str(args.feature_dir).split("/")[-2]
save_results(exp_dir=params.exp_dir, test_set_name=test_set_name, results=results)
logging.info("Done!")
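The two reformatted comprehensions in decode_dataset build the reference strings that decoding is scored against; for a cut whose custom "frames" field were, say, ["change language", "none", "none"] (an illustrative value), the steps give:

frames = ["change language", "none", "none"]  # illustrative custom["frames"] value
text = " ".join(frames)  # "change language none none"
text = "<s> " + text.replace("change language", "change_language") + " </s>"
# -> "<s> change_language none none </s>"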

View File

@@ -282,11 +282,8 @@ class SluDataModule(DataModule):
)
return cuts_valid
@lru_cache()
def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts")
cuts_test = load_manifest_lazy(
self.args.feature_dir / "slu_cuts_test.jsonl.gz"
)
cuts_test = load_manifest_lazy(self.args.feature_dir / "slu_cuts_test.jsonl.gz")
return cuts_test

View File

@@ -26,14 +26,15 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from slu_datamodule import SluDataModule
from lhotse.utils import fix_random_seed
from slu_datamodule import SluDataModule
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from transducer.conformer import Conformer
# from torch.utils.tensorboard import SummaryWriter
from transducer.decoder import Decoder
from transducer.conformer import Conformer
from transducer.joiner import Joiner
from transducer.model import Transducer
@@ -49,20 +50,20 @@ def get_word2id(params):
# 0 is blank
id = 1
with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
for line in lexicon_file:
if len(line.strip()) > 0:
word2id[line.split()[0]] = id
id += 1
return word2id
return word2id
def get_labels(texts: List[str], word2id) -> k2.RaggedTensor:
"""
Args:
texts:
A list of transcripts.
A list of transcripts.
Returns:
Return a ragged tensor containing the corresponding word ID.
"""
@@ -133,11 +134,7 @@ def get_parser():
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lm/frames"
)
parser.add_argument("--lang-dir", type=str, default="data/lm/frames")
return parser
@@ -215,9 +212,11 @@ def get_params() -> AttributeDict:
)
vocab_size = 1
with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
for line in lexicon_file:
if len(line.strip()) > 0:# and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
if (
len(line.strip()) > 0
): # and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
vocab_size += 1
params.vocab_size = vocab_size
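For reference, get_word2id above assigns consecutive IDs starting at 1 (0 is reserved for blank), and get_labels earlier in this file is assumed to split each "<s> ... </s>" frame string on whitespace and look the words up in that table before wrapping them in a k2 ragged tensor. A minimal sketch of that assumption:

import k2

word2id = {"<s>": 1, "</s>": 2, "change_language": 3, "none": 4}  # illustrative table
texts = ["<s> change_language none none </s>"]
label_ids = [[word2id[word] for word in text.split()] for text in texts]
labels = k2.RaggedTensor(label_ids)
# labels holds one row of word IDs per utterance: [ [ 1 3 4 4 2 ] ]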
@@ -312,11 +311,7 @@ def save_checkpoint(
def compute_loss(
params: AttributeDict,
model: nn.Module,
batch: dict,
is_training: bool,
word2ids
params: AttributeDict, model: nn.Module, batch: dict, is_training: bool, word2ids
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute RNN-T loss given the model and its inputs.
@@ -342,8 +337,14 @@ def compute_loss(
feature_lens = batch["supervisions"]["num_frames"].to(device)
texts = [' '.join(a.supervisions[0].custom["frames"]) for a in batch["supervisions"]["cut"]]
texts = ['<s> ' + a.replace('change language', 'change_language') + ' </s>' for a in texts]
texts = [
" ".join(a.supervisions[0].custom["frames"])
for a in batch["supervisions"]["cut"]
]
texts = [
"<s> " + a.replace("change language", "change_language") + " </s>"
for a in texts
]
labels = get_labels(texts, word2ids).to(device)
with torch.set_grad_enabled(is_training):
@@ -378,7 +379,7 @@ def compute_validation_loss(
model=model,
batch=batch,
is_training=False,
word2ids=word2ids
word2ids=word2ids,
)
assert loss.requires_grad is False
@@ -437,11 +438,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
is_training=True,
word2ids=word2ids
params=params, model=model, batch=batch, is_training=True, word2ids=word2ids
)
# summary stats.
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -471,7 +468,7 @@ def train_one_epoch(
model=model,
valid_dl=valid_dl,
world_size=world_size,
word2ids=word2ids
word2ids=word2ids,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
@@ -593,7 +590,7 @@ def run(rank, world_size, args):
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
word2ids=word2ids
word2ids=word2ids,
)
save_checkpoint(