Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-04 06:34:20 +00:00
Fix style check issues
Signed-off-by: Xinyuan Li <xli257@b17.clsp.jhu.edu>
This commit is contained in:
parent 7047a579b8
commit eec59410f1
@@ -7,8 +7,9 @@ It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""

import argparse
import logging
import os, argparse
import os
from pathlib import Path

import torch

@@ -82,9 +83,10 @@ def compute_fbank_slu(manifest_dir, fbanks_dir):
)
cut_set.to_file(cuts_file)


parser = argparse.ArgumentParser()
parser.add_argument('manifest_dir')
parser.add_argument('fbanks_dir')
parser.add_argument("manifest_dir")
parser.add_argument("fbanks_dir")

if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
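For context: the two positional arguments above feed straight into the feature-extraction entry point named in the hunk header. A minimal sketch of how the reformatted __main__ block plausibly wires together, assuming compute_fbank_slu (whose body is not shown in this diff) takes the two directories directly; this is illustrative, not part of the commit:

import argparse
import logging

def compute_fbank_slu(manifest_dir, fbanks_dir):
    # Stand-in for the real function defined earlier in the file;
    # only its signature is visible in the hunk header above.
    ...

parser = argparse.ArgumentParser()
parser.add_argument("manifest_dir")
parser.add_argument("fbanks_dir")

if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    args = parser.parse_args()
    compute_fbank_slu(args.manifest_dir, args.fbanks_dir)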
@@ -1,12 +1,22 @@
import pandas, argparse
import argparse

import pandas
from tqdm import tqdm


def generate_lexicon(corpus_dir, lm_dir):
data = pandas.read_csv(str(corpus_dir) + '/data/train_data.csv', index_col = 0, header = 0)
data = pandas.read_csv(
str(corpus_dir) + "/data/train_data.csv", index_col=0, header=0
)
vocab_transcript = set()
vocab_frames = set()
transcripts = data['transcription'].tolist()
frames = list(i for i in zip(data['action'].tolist(), data['object'].tolist(), data['location'].tolist()))
transcripts = data["transcription"].tolist()
frames = list(
i
for i in zip(
data["action"].tolist(), data["object"].tolist(), data["location"].tolist()
)
)

for transcript in tqdm(transcripts):
for word in transcript.split():

@@ -14,34 +24,36 @@ def generate_lexicon(corpus_dir, lm_dir):

for frame in tqdm(frames):
for word in frame:
vocab_frames.add('_'.join(word.split()))
vocab_frames.add("_".join(word.split()))

with open(lm_dir + '/words_transcript.txt', 'w') as lexicon_transcript_file:
lexicon_transcript_file.write("<UNK> 1" + '\n')
lexicon_transcript_file.write("<s> 2" + '\n')
lexicon_transcript_file.write("</s> 0" + '\n')
with open(lm_dir + "/words_transcript.txt", "w") as lexicon_transcript_file:
lexicon_transcript_file.write("<UNK> 1" + "\n")
lexicon_transcript_file.write("<s> 2" + "\n")
lexicon_transcript_file.write("</s> 0" + "\n")
id = 3
for vocab in vocab_transcript:
lexicon_transcript_file.write(vocab + ' ' + str(id) + '\n')
lexicon_transcript_file.write(vocab + " " + str(id) + "\n")
id += 1

with open(lm_dir + '/words_frames.txt', 'w') as lexicon_frames_file:
lexicon_frames_file.write("<UNK> 1" + '\n')
lexicon_frames_file.write("<s> 2" + '\n')
lexicon_frames_file.write("</s> 0" + '\n')
with open(lm_dir + "/words_frames.txt", "w") as lexicon_frames_file:
lexicon_frames_file.write("<UNK> 1" + "\n")
lexicon_frames_file.write("<s> 2" + "\n")
lexicon_frames_file.write("</s> 0" + "\n")
id = 3
for vocab in vocab_frames:
lexicon_frames_file.write(vocab + ' ' + str(id) + '\n')
lexicon_frames_file.write(vocab + " " + str(id) + "\n")
id += 1


parser = argparse.ArgumentParser()
parser.add_argument('corpus_dir')
parser.add_argument('lm_dir')
parser.add_argument("corpus_dir")
parser.add_argument("lm_dir")


def main():
args = parser.parse_args()

generate_lexicon(args.corpus_dir, args.lm_dir)


main()
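The word lists written above use a plain "word id" format: </s>, <UNK>, and <s> are pinned to ids 0, 1, and 2, and the remaining vocabulary is numbered from 3. A small illustrative reader for such a file (the reader itself is not part of the repo; the path below is only an example built from the lm_dir argument above):

def read_word_table(path):
    # Each line looks like "</s> 0", "<UNK> 1", "<s> 2", "change_language 3", ...
    word2id = {}
    with open(path) as f:
        for line in f:
            if line.strip():
                word, idx = line.split()
                word2id[word] = int(idx)
    return word2id

word2id = read_word_table("data/lm/frames/words_frames.txt")  # example path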
@@ -19,11 +19,11 @@ consisting of words and tokens (i.e., phones) and does the following:

5. Generate L_disambig.pt, in k2 format.
"""
import argparse
import math
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
import argparse

import k2
import torch

@@ -299,8 +299,10 @@ def lexicon_to_fst(
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa


parser = argparse.ArgumentParser()
parser.add_argument('lm_dir')
parser.add_argument("lm_dir")


def main():
args = parser.parse_args()

@@ -328,7 +330,7 @@ def main():
tokens.append(f"#{i}")

tokens = ["<eps>"] + tokens
words = ['eps'] + words + ["#0", "!SIL"]
words = ["eps"] + words + ["#0", "!SIL"]

token2id = generate_id_map(tokens)
word2id = generate_id_map(words)
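generate_id_map is not shown in this hunk; in a lang-preparation script like this it typically just enumerates the symbol list, so that the "<eps>"/"eps" entries prepended above land on id 0. A hedged sketch of that assumed behaviour:

from typing import Dict, List

def generate_id_map(symbols: List[str]) -> Dict[str, int]:
    # Assumed behaviour: the position in the list becomes the integer id.
    return {sym: i for i, sym in enumerate(symbols)}

token2id = generate_id_map(["<eps>", "a", "b", "#0"])  # {'<eps>': 0, 'a': 1, ...}
word2id = generate_id_map(["eps", "turn", "on", "#0", "!SIL"])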
@@ -20,7 +20,9 @@ import torch
from transducer.model import Transducer


def greedy_search(model: Transducer, encoder_out: torch.Tensor, id2word: dict) -> List[str]:
def greedy_search(
model: Transducer, encoder_out: torch.Tensor, id2word: dict
) -> List[str]:
"""
Args:
model:
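Only the signature is reflowed here, but for orientation: a transducer greedy search with this signature walks the encoder output frame by frame and emits non-blank symbols. The sketch below is an assumption about how such a routine commonly looks; the blank id of 0 and the model.decoder/model.joiner call shapes are not taken from this diff:

from typing import List

import torch

def greedy_search_sketch(model, encoder_out: torch.Tensor, id2word: dict) -> List[str]:
    # encoder_out: (1, T, C) frames from the acoustic encoder.
    blank_id = 0  # assumed
    device = encoder_out.device
    hyp = [blank_id]
    decoder_out = model.decoder(torch.tensor([hyp], device=device))
    words: List[str] = []
    for t in range(encoder_out.size(1)):
        # Combine the current acoustic frame with the decoder state and pick the best symbol.
        logits = model.joiner(encoder_out[:, t : t + 1, :], decoder_out)
        y = int(logits.argmax(dim=-1).item())
        if y != blank_id:
            hyp.append(y)
            words.append(id2word[y])
            decoder_out = model.decoder(torch.tensor([hyp], device=device))
    return words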
@@ -22,12 +22,12 @@ from typing import List, Tuple

import torch
import torch.nn as nn
from transducer.slu_datamodule import SluDataModule
from transducer.beam_search import greedy_search
from transducer.decoder import Decoder
from transducer.conformer import Conformer
from transducer.decoder import Decoder
from transducer.joiner import Joiner
from transducer.model import Transducer
from transducer.slu_datamodule import SluDataModule

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info

@@ -45,7 +45,7 @@ def get_id2word(params):
# 0 is blank
id = 1
try:
with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
for line in lexicon_file:
if len(line.strip()) > 0:
id2word[id] = line.split()[0]
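The loop above (truncated by the hunk) builds the output vocabulary by position: id 0 is reserved for blank, then every non-empty line of lexicon_disambig.txt contributes its first column. A compact sketch of the same idea; the id increment at the end is cut off in the diff and assumed:

from pathlib import Path

def get_id2word_sketch(lang_dir: str) -> dict:
    id2word = {}
    # 0 is blank
    idx = 1
    with open(Path(lang_dir) / "lexicon_disambig.txt") as lexicon_file:
        for line in lexicon_file:
            if len(line.strip()) > 0:
                id2word[idx] = line.split()[0]
                idx += 1  # assumed continuation of the truncated hunk
    return id2word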
@@ -82,11 +82,7 @@ def get_parser():
default="transducer/exp",
help="Directory from which to load the checkpoints",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lm/frames"
)
parser.add_argument("--lang-dir", type=str, default="data/lm/frames")

return parser

@@ -106,9 +102,11 @@ def get_params() -> AttributeDict:
)

vocab_size = 1
with open(params.lang_dir / 'lexicon_disambig.txt') as lexicon_file:
with open(params.lang_dir / "lexicon_disambig.txt") as lexicon_file:
for line in lexicon_file:
if len(line.strip()) > 0:# and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
if (
len(line.strip()) > 0
): # and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
vocab_size += 1
params.vocab_size = vocab_size

@@ -116,10 +114,7 @@ def get_params() -> AttributeDict:


def decode_one_batch(
params: AttributeDict,
model: nn.Module,
batch: dict,
id2word: dict
params: AttributeDict, model: nn.Module, batch: dict, id2word: dict
) -> List[List[int]]:
"""Decode one batch and return the result in a list-of-list.
Each sub list contains the word IDs for an utterance in the batch.

@@ -195,15 +190,18 @@ def decode_dataset(

results = []
for batch_idx, batch in enumerate(dl):
texts = [' '.join(a.supervisions[0].custom["frames"]) for a in batch["supervisions"]["cut"]]
texts = ['<s> ' + a.replace('change language', 'change_language') + ' </s>' for a in texts]
texts = [
" ".join(a.supervisions[0].custom["frames"])
for a in batch["supervisions"]["cut"]
]
texts = [
"<s> " + a.replace("change language", "change_language") + " </s>"
for a in texts
]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

hyps = decode_one_batch(
params=params,
model=model,
batch=batch,
id2word=id2word
params=params, model=model, batch=batch, id2word=id2word
)

this_batch = []
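The two reformatted comprehensions above only change layout; the underlying transform joins each cut's frame labels into one string, fuses the two-word intent "change language" into a single token, and wraps the result in sentence markers. A toy example of the same transform (the frame values are invented for illustration):

frames = [["activate", "lights", "kitchen"], ["change language", "none", "none"]]

texts = [" ".join(a) for a in frames]
texts = [
    "<s> " + a.replace("change language", "change_language") + " </s>"
    for a in texts
]
# texts == ["<s> activate lights kitchen </s>",
#           "<s> change_language none none </s>"]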
@@ -338,7 +336,7 @@ def main():
model=model,
)

test_set_name=str(args.feature_dir).split('/')[-2]
test_set_name = str(args.feature_dir).split("/")[-2]
save_results(exp_dir=params.exp_dir, test_set_name=test_set_name, results=results)

logging.info("Done!")
@@ -282,11 +282,8 @@ class SluDataModule(DataModule):
)
return cuts_valid


@lru_cache()
def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts")
cuts_test = load_manifest_lazy(
self.args.feature_dir / "slu_cuts_test.jsonl.gz"
)
cuts_test = load_manifest_lazy(self.args.feature_dir / "slu_cuts_test.jsonl.gz")
return cuts_test
@@ -26,14 +26,15 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from slu_datamodule import SluDataModule
from lhotse.utils import fix_random_seed
from slu_datamodule import SluDataModule
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from transducer.conformer import Conformer

# from torch.utils.tensorboard import SummaryWriter
from transducer.decoder import Decoder
from transducer.conformer import Conformer
from transducer.joiner import Joiner
from transducer.model import Transducer

@@ -49,7 +50,7 @@ def get_word2id(params):

# 0 is blank
id = 1
with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
for line in lexicon_file:
if len(line.strip()) > 0:
word2id[line.split()[0]] = id

@@ -133,11 +134,7 @@ def get_parser():
help="The seed for random generators intended for reproducibility",
)

parser.add_argument(
"--lang-dir",
type=str,
default="data/lm/frames"
)
parser.add_argument("--lang-dir", type=str, default="data/lm/frames")

return parser

@@ -215,9 +212,11 @@ def get_params() -> AttributeDict:
)

vocab_size = 1
with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
for line in lexicon_file:
if len(line.strip()) > 0:# and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
if (
len(line.strip()) > 0
): # and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
vocab_size += 1
params.vocab_size = vocab_size
@@ -312,11 +311,7 @@ def save_checkpoint(


def compute_loss(
params: AttributeDict,
model: nn.Module,
batch: dict,
is_training: bool,
word2ids
params: AttributeDict, model: nn.Module, batch: dict, is_training: bool, word2ids
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute RNN-T loss given the model and its inputs.

@@ -342,8 +337,14 @@ def compute_loss(

feature_lens = batch["supervisions"]["num_frames"].to(device)

texts = [' '.join(a.supervisions[0].custom["frames"]) for a in batch["supervisions"]["cut"]]
texts = ['<s> ' + a.replace('change language', 'change_language') + ' </s>' for a in texts]
texts = [
" ".join(a.supervisions[0].custom["frames"])
for a in batch["supervisions"]["cut"]
]
texts = [
"<s> " + a.replace("change language", "change_language") + " </s>"
for a in texts
]
labels = get_labels(texts, word2ids).to(device)

with torch.set_grad_enabled(is_training):

@@ -378,7 +379,7 @@ def compute_validation_loss(
model=model,
batch=batch,
is_training=False,
word2ids=word2ids
word2ids=word2ids,
)
assert loss.requires_grad is False

@@ -437,11 +438,7 @@ def train_one_epoch(
batch_size = len(batch["supervisions"]["text"])

loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
is_training=True,
word2ids=word2ids
params=params, model=model, batch=batch, is_training=True, word2ids=word2ids
)
# summary stats.
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

@@ -471,7 +468,7 @@ def train_one_epoch(
model=model,
valid_dl=valid_dl,
world_size=world_size,
word2ids=word2ids
word2ids=word2ids,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")

@@ -593,7 +590,7 @@ def run(rank, world_size, args):
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
word2ids=word2ids
word2ids=word2ids,
)

save_checkpoint(
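get_labels itself is outside this diff; given the word2id mapping built in get_word2id above, it presumably converts each "<s> ... </s>" string into a row of word ids before the RNN-T loss is computed. A hedged sketch of that step (the padding scheme and tensor layout below are assumptions, not repo code):

import torch

def get_labels_sketch(texts, word2id, pad_id=0):
    # Map each space-separated word to its id; unknown words fall back to <UNK>.
    rows = [
        [word2id.get(w, word2id.get("<UNK>", pad_id)) for w in t.split()]
        for t in texts
    ]
    max_len = max(len(r) for r in rows)
    return torch.tensor([r + [pad_id] * (max_len - len(r)) for r in rows])

labels = get_labels_sketch(
    ["<s> change_language none none </s>"],
    {"</s>": 0, "<UNK>": 1, "<s>": 2, "change_language": 3, "none": 4},
)
# labels == tensor([[2, 3, 4, 4, 0]])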