diff --git a/egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py b/egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py
index ce3a3aaf8..a51b7b47b 100755
--- a/egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py
+++ b/egs/fluent_speech_commands/SLU/local/compute_fbank_slu.py
@@ -7,8 +7,9 @@ It looks for manifests in the directory data/manifests.
 The generated fbank features are saved in data/fbank.
 """
+import argparse
 import logging
-import os, argparse
+import os
 from pathlib import Path
 
 import torch
@@ -82,9 +83,10 @@ def compute_fbank_slu(manifest_dir, fbanks_dir):
         )
         cut_set.to_file(cuts_file)
 
+
 parser = argparse.ArgumentParser()
-parser.add_argument('manifest_dir')
-parser.add_argument('fbanks_dir')
+parser.add_argument("manifest_dir")
+parser.add_argument("fbanks_dir")
 
 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
diff --git a/egs/fluent_speech_commands/SLU/local/generate_lexicon.py b/egs/fluent_speech_commands/SLU/local/generate_lexicon.py
index 6613e4217..6263e062f 100755
--- a/egs/fluent_speech_commands/SLU/local/generate_lexicon.py
+++ b/egs/fluent_speech_commands/SLU/local/generate_lexicon.py
@@ -1,12 +1,22 @@
-import pandas, argparse
+import argparse
+
+import pandas
 from tqdm import tqdm
 
+
 def generate_lexicon(corpus_dir, lm_dir):
-    data = pandas.read_csv(str(corpus_dir) + '/data/train_data.csv', index_col = 0, header = 0)
+    data = pandas.read_csv(
+        str(corpus_dir) + "/data/train_data.csv", index_col=0, header=0
+    )
     vocab_transcript = set()
     vocab_frames = set()
-    transcripts = data['transcription'].tolist()
-    frames = list(i for i in zip(data['action'].tolist(), data['object'].tolist(), data['location'].tolist()))
+    transcripts = data["transcription"].tolist()
+    frames = list(
+        i
+        for i in zip(
+            data["action"].tolist(), data["object"].tolist(), data["location"].tolist()
+        )
+    )
 
     for transcript in tqdm(transcripts):
         for word in transcript.split():
@@ -14,34 +24,36 @@ def generate_lexicon(corpus_dir, lm_dir):
 
     for frame in tqdm(frames):
         for word in frame:
-            vocab_frames.add('_'.join(word.split()))
-    
-    with open(lm_dir + '/words_transcript.txt', 'w') as lexicon_transcript_file:
-        lexicon_transcript_file.write(" 1" + '\n')
-        lexicon_transcript_file.write(" 2" + '\n')
-        lexicon_transcript_file.write(" 0" + '\n')
+            vocab_frames.add("_".join(word.split()))
+
+    with open(lm_dir + "/words_transcript.txt", "w") as lexicon_transcript_file:
+        lexicon_transcript_file.write(" 1" + "\n")
+        lexicon_transcript_file.write(" 2" + "\n")
+        lexicon_transcript_file.write(" 0" + "\n")
         id = 3
         for vocab in vocab_transcript:
-            lexicon_transcript_file.write(vocab + ' ' + str(id) + '\n')
+            lexicon_transcript_file.write(vocab + " " + str(id) + "\n")
             id += 1
 
-    with open(lm_dir + '/words_frames.txt', 'w') as lexicon_frames_file:
-        lexicon_frames_file.write(" 1" + '\n')
-        lexicon_frames_file.write(" 2" + '\n')
-        lexicon_frames_file.write(" 0" + '\n')
+    with open(lm_dir + "/words_frames.txt", "w") as lexicon_frames_file:
+        lexicon_frames_file.write(" 1" + "\n")
+        lexicon_frames_file.write(" 2" + "\n")
+        lexicon_frames_file.write(" 0" + "\n")
         id = 3
         for vocab in vocab_frames:
-            lexicon_frames_file.write(vocab + ' ' + str(id) + '\n')
+            lexicon_frames_file.write(vocab + " " + str(id) + "\n")
             id += 1
-    
+
 
 parser = argparse.ArgumentParser()
-parser.add_argument('corpus_dir')
-parser.add_argument('lm_dir')
+parser.add_argument("corpus_dir")
+parser.add_argument("lm_dir")
+
 
 def main():
     args = parser.parse_args()
-    
+
     generate_lexicon(args.corpus_dir, args.lm_dir)
-main()
\ No newline at end of file
+
+main()
diff --git a/egs/fluent_speech_commands/SLU/local/prepare_lang.py b/egs/fluent_speech_commands/SLU/local/prepare_lang.py
index 61aafb8ed..2a71dcf81 100755
--- a/egs/fluent_speech_commands/SLU/local/prepare_lang.py
+++ b/egs/fluent_speech_commands/SLU/local/prepare_lang.py
@@ -19,11 +19,11 @@ consisting of words and tokens (i.e., phones) and does the following:
 5. Generate L_disambig.pt, in k2 format.
 """
+import argparse
 import math
 from collections import defaultdict
 from pathlib import Path
 from typing import Any, Dict, List, Tuple
-import argparse
 
 import k2
 import torch
@@ -299,8 +299,10 @@ def lexicon_to_fst(
     fsa = k2.Fsa.from_str(arcs, acceptor=False)
     return fsa
 
+
 parser = argparse.ArgumentParser()
-parser.add_argument('lm_dir')
+parser.add_argument("lm_dir")
+
 
 def main():
     args = parser.parse_args()
@@ -312,58 +314,58 @@ def main():
     sil_prob = 0.5
 
     for name, lexicon_filename in zip(names, lexicon_filenames):
-        lexicon = read_lexicon(lexicon_filename)
-        tokens = get_words(lexicon)
-        words = get_words(lexicon)
-        new_lexicon = []
-        for lexicon_item in lexicon:
-            new_lexicon.append((lexicon_item[0], [lexicon_item[0]]))
-        lexicon = new_lexicon
+        lexicon = read_lexicon(lexicon_filename)
+        tokens = get_words(lexicon)
+        words = get_words(lexicon)
+        new_lexicon = []
+        for lexicon_item in lexicon:
+            new_lexicon.append((lexicon_item[0], [lexicon_item[0]]))
+        lexicon = new_lexicon
 
-        lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+        lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
 
-        for i in range(max_disambig + 1):
-            disambig = f"#{i}"
-            assert disambig not in tokens
-            tokens.append(f"#{i}")
+        for i in range(max_disambig + 1):
+            disambig = f"#{i}"
+            assert disambig not in tokens
+            tokens.append(f"#{i}")
 
-        tokens = [""] + tokens
-        words = ['eps'] + words + ["#0", "!SIL"]
+        tokens = [""] + tokens
+        words = ["eps"] + words + ["#0", "!SIL"]
 
-        token2id = generate_id_map(tokens)
-        word2id = generate_id_map(words)
+        token2id = generate_id_map(tokens)
+        word2id = generate_id_map(words)
 
-        write_mapping(out_dir / ("tokens_" + name + ".txt"), token2id)
-        write_mapping(out_dir / ("words_" + name + ".txt"), word2id)
-        write_lexicon(out_dir / ("lexicon_disambig_" + name + ".txt"), lexicon_disambig)
+        write_mapping(out_dir / ("tokens_" + name + ".txt"), token2id)
+        write_mapping(out_dir / ("words_" + name + ".txt"), word2id)
+        write_lexicon(out_dir / ("lexicon_disambig_" + name + ".txt"), lexicon_disambig)
 
-        L = lexicon_to_fst(
-            lexicon,
-            token2id=word2id,
-            word2id=word2id,
-            sil_token=sil_token,
-            sil_prob=sil_prob,
-        )
+        L = lexicon_to_fst(
+            lexicon,
+            token2id=word2id,
+            word2id=word2id,
+            sil_token=sil_token,
+            sil_prob=sil_prob,
+        )
 
-        L_disambig = lexicon_to_fst(
-            lexicon_disambig,
-            token2id=word2id,
-            word2id=word2id,
-            sil_token=sil_token,
-            sil_prob=sil_prob,
-            need_self_loops=True,
-        )
-        torch.save(L.as_dict(), out_dir / ("L_" + name + ".pt"))
-        torch.save(L_disambig.as_dict(), out_dir / ("L_disambig_" + name + ".pt"))
+        L_disambig = lexicon_to_fst(
+            lexicon_disambig,
+            token2id=word2id,
+            word2id=word2id,
+            sil_token=sil_token,
+            sil_prob=sil_prob,
+            need_self_loops=True,
+        )
+        torch.save(L.as_dict(), out_dir / ("L_" + name + ".pt"))
+        torch.save(L_disambig.as_dict(), out_dir / ("L_disambig_" + name + ".pt"))
 
-        if False:
-            # Just for debugging, will remove it
-            L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
-            L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
-            L_disambig.labels_sym = L.labels_sym
-            L_disambig.aux_labels_sym = L.aux_labels_sym
-            L.draw(out_dir / "L.png", title="L")
-            L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
+        if False:
+            # Just for debugging, will remove it
+            L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
+            L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
+            L_disambig.labels_sym = L.labels_sym
+            L_disambig.aux_labels_sym = L.aux_labels_sym
+            L.draw(out_dir / "L.png", title="L")
+            L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
 
 
 main()
diff --git a/egs/fluent_speech_commands/SLU/transducer/beam_search.py b/egs/fluent_speech_commands/SLU/transducer/beam_search.py
index 3c6e2b34d..a16aa0123 100755
--- a/egs/fluent_speech_commands/SLU/transducer/beam_search.py
+++ b/egs/fluent_speech_commands/SLU/transducer/beam_search.py
@@ -20,7 +20,9 @@ import torch
 from transducer.model import Transducer
 
 
-def greedy_search(model: Transducer, encoder_out: torch.Tensor, id2word: dict) -> List[str]:
+def greedy_search(
+    model: Transducer, encoder_out: torch.Tensor, id2word: dict
+) -> List[str]:
     """
     Args:
       model:
diff --git a/egs/fluent_speech_commands/SLU/transducer/decode.py b/egs/fluent_speech_commands/SLU/transducer/decode.py
index 8e19cb526..ba2b9aaea 100755
--- a/egs/fluent_speech_commands/SLU/transducer/decode.py
+++ b/egs/fluent_speech_commands/SLU/transducer/decode.py
@@ -22,12 +22,12 @@ from typing import List, Tuple
 
 import torch
 import torch.nn as nn
-from transducer.slu_datamodule import SluDataModule
 from transducer.beam_search import greedy_search
-from transducer.decoder import Decoder
 from transducer.conformer import Conformer
+from transducer.decoder import Decoder
 from transducer.joiner import Joiner
 from transducer.model import Transducer
+from transducer.slu_datamodule import SluDataModule
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.env import get_env_info
@@ -45,7 +45,7 @@ def get_id2word(params):
     # 0 is blank
     id = 1
     try:
-        with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
+        with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
             for line in lexicon_file:
                 if len(line.strip()) > 0:
                     id2word[id] = line.split()[0]
@@ -82,11 +82,7 @@ def get_parser():
         default="transducer/exp",
         help="Directory from which to load the checkpoints",
     )
 
-    parser.add_argument(
-        "--lang-dir",
-        type=str,
-        default="data/lm/frames"
-    )
+    parser.add_argument("--lang-dir", type=str, default="data/lm/frames")
 
     return parser
@@ -106,9 +102,11 @@ def get_params() -> AttributeDict:
     )
 
     vocab_size = 1
-    with open(params.lang_dir / 'lexicon_disambig.txt') as lexicon_file:
+    with open(params.lang_dir / "lexicon_disambig.txt") as lexicon_file:
         for line in lexicon_file:
-            if len(line.strip()) > 0:# and '' not in line and '' not in line and '' not in line:
+            if (
+                len(line.strip()) > 0
+            ):  # and '' not in line and '' not in line and '' not in line:
                 vocab_size += 1
 
     params.vocab_size = vocab_size
@@ -116,10 +114,7 @@
 
 
 def decode_one_batch(
-    params: AttributeDict,
-    model: nn.Module,
-    batch: dict,
-    id2word: dict
+    params: AttributeDict, model: nn.Module, batch: dict, id2word: dict
 ) -> List[List[int]]:
     """Decode one batch and return the result in a list-of-list.
     Each sub list contains the word IDs for an utterance in the batch.
@@ -195,15 +190,18 @@ def decode_dataset(
 
     results = []
     for batch_idx, batch in enumerate(dl):
-        texts = [' '.join(a.supervisions[0].custom["frames"]) for a in batch["supervisions"]["cut"]]
-        texts = [' ' + a.replace('change language', 'change_language') + ' ' for a in texts]
+        texts = [
+            " ".join(a.supervisions[0].custom["frames"])
+            for a in batch["supervisions"]["cut"]
+        ]
+        texts = [
+            " " + a.replace("change language", "change_language") + " "
+            for a in texts
+        ]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
         hyps = decode_one_batch(
-            params=params,
-            model=model,
-            batch=batch,
-            id2word=id2word
+            params=params, model=model, batch=batch, id2word=id2word
         )
 
         this_batch = []
@@ -338,7 +336,7 @@ def main():
         model=model,
     )
 
-    test_set_name=str(args.feature_dir).split('/')[-2]
+    test_set_name = str(args.feature_dir).split("/")[-2]
     save_results(exp_dir=params.exp_dir, test_set_name=test_set_name, results=results)
 
     logging.info("Done!")
diff --git a/egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py b/egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py
index bffd52e4c..fa715abdd 100755
--- a/egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py
+++ b/egs/fluent_speech_commands/SLU/transducer/slu_datamodule.py
@@ -282,11 +282,8 @@ class SluDataModule(DataModule):
         )
         return cuts_valid
-    
 
     @lru_cache()
     def test_cuts(self) -> List[CutSet]:
         logging.info("About to get test cuts")
-        cuts_test = load_manifest_lazy(
-            self.args.feature_dir / "slu_cuts_test.jsonl.gz"
-        )
+        cuts_test = load_manifest_lazy(self.args.feature_dir / "slu_cuts_test.jsonl.gz")
         return cuts_test
diff --git a/egs/fluent_speech_commands/SLU/transducer/train.py b/egs/fluent_speech_commands/SLU/transducer/train.py
index f2dd0ca67..a59c0b754 100755
--- a/egs/fluent_speech_commands/SLU/transducer/train.py
+++ b/egs/fluent_speech_commands/SLU/transducer/train.py
@@ -26,14 +26,15 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
-from slu_datamodule import SluDataModule
 from lhotse.utils import fix_random_seed
+from slu_datamodule import SluDataModule
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 
+from transducer.conformer import Conformer
+
 # from torch.utils.tensorboard import SummaryWriter
 from transducer.decoder import Decoder
-from transducer.conformer import Conformer
 from transducer.joiner import Joiner
 from transducer.model import Transducer
@@ -49,20 +50,20 @@
 def get_word2id(params):
 
     # 0 is blank
     id = 1
-    with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
+    with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
         for line in lexicon_file:
             if len(line.strip()) > 0:
                 word2id[line.split()[0]] = id
                 id += 1
 
-    return word2id
+    return word2id
 
 
 def get_labels(texts: List[str], word2id) -> k2.RaggedTensor:
     """
     Args:
       texts:
-        A list of transcripts.
+        A list of transcripts.
     Returns:
       Return a ragged tensor containing the corresponding word ID.
     """
@@ -133,11 +134,7 @@ def get_parser():
         help="The seed for random generators intended for reproducibility",
     )
 
-    parser.add_argument(
-        "--lang-dir",
-        type=str,
-        default="data/lm/frames"
-    )
+    parser.add_argument("--lang-dir", type=str, default="data/lm/frames")
 
     return parser
@@ -215,9 +212,11 @@ def get_params() -> AttributeDict:
     )
 
     vocab_size = 1
-    with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
+    with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
         for line in lexicon_file:
-            if len(line.strip()) > 0:# and '' not in line and '' not in line and '' not in line:
+            if (
+                len(line.strip()) > 0
+            ):  # and '' not in line and '' not in line and '' not in line:
                 vocab_size += 1
 
     params.vocab_size = vocab_size
@@ -312,11 +311,7 @@ def save_checkpoint(
 
 
 def compute_loss(
-    params: AttributeDict,
-    model: nn.Module,
-    batch: dict,
-    is_training: bool,
-    word2ids
+    params: AttributeDict, model: nn.Module, batch: dict, is_training: bool, word2ids
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute RNN-T loss given the model and its inputs.
@@ -342,8 +337,14 @@ def compute_loss(
 
     feature_lens = batch["supervisions"]["num_frames"].to(device)
 
-    texts = [' '.join(a.supervisions[0].custom["frames"]) for a in batch["supervisions"]["cut"]]
-    texts = [' ' + a.replace('change language', 'change_language') + ' ' for a in texts]
+    texts = [
+        " ".join(a.supervisions[0].custom["frames"])
+        for a in batch["supervisions"]["cut"]
+    ]
+    texts = [
+        " " + a.replace("change language", "change_language") + " "
+        for a in texts
+    ]
     labels = get_labels(texts, word2ids).to(device)
 
     with torch.set_grad_enabled(is_training):
@@ -378,7 +379,7 @@ def compute_validation_loss(
             model=model,
             batch=batch,
             is_training=False,
-            word2ids=word2ids
+            word2ids=word2ids,
         )
 
         assert loss.requires_grad is False
@@ -437,11 +438,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         loss, loss_info = compute_loss(
-            params=params,
-            model=model,
-            batch=batch,
-            is_training=True,
-            word2ids=word2ids
+            params=params, model=model, batch=batch, is_training=True, word2ids=word2ids
         )
         # summary stats.
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -471,7 +468,7 @@ def train_one_epoch(
                 model=model,
                 valid_dl=valid_dl,
                 world_size=world_size,
-                word2ids=word2ids
+                word2ids=word2ids,
             )
             model.train()
             logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
@@ -593,7 +590,7 @@ def run(rank, world_size, args):
             valid_dl=valid_dl,
             tb_writer=tb_writer,
             world_size=world_size,
-            word2ids=word2ids
+            word2ids=word2ids,
        )
 
        save_checkpoint(