Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-04 06:34:20 +00:00)

commit eec59410f1 (parent 7047a579b8)

    Fix style check issues

    Signed-off-by: Xinyuan Li <xli257@b17.clsp.jhu.edu>
@@ -7,8 +7,9 @@ It looks for manifests in the directory data/manifests.
 The generated fbank features are saved in data/fbank.
 """
 
+import argparse
 import logging
-import os, argparse
+import os
 from pathlib import Path
 
 import torch
@@ -82,9 +83,10 @@ def compute_fbank_slu(manifest_dir, fbanks_dir):
             )
             cut_set.to_file(cuts_file)
 
+
 parser = argparse.ArgumentParser()
-parser.add_argument('manifest_dir')
-parser.add_argument('fbanks_dir')
+parser.add_argument("manifest_dir")
+parser.add_argument("fbanks_dir")
 
 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
@@ -1,12 +1,22 @@
-import pandas, argparse
+import argparse
+
+import pandas
 from tqdm import tqdm
 
+
 def generate_lexicon(corpus_dir, lm_dir):
-    data = pandas.read_csv(str(corpus_dir) + '/data/train_data.csv', index_col = 0, header = 0)
+    data = pandas.read_csv(
+        str(corpus_dir) + "/data/train_data.csv", index_col=0, header=0
+    )
     vocab_transcript = set()
     vocab_frames = set()
-    transcripts = data['transcription'].tolist()
-    frames = list(i for i in zip(data['action'].tolist(), data['object'].tolist(), data['location'].tolist()))
+    transcripts = data["transcription"].tolist()
+    frames = list(
+        i
+        for i in zip(
+            data["action"].tolist(), data["object"].tolist(), data["location"].tolist()
+        )
+    )
 
     for transcript in tqdm(transcripts):
         for word in transcript.split():
@@ -14,34 +24,36 @@ def generate_lexicon(corpus_dir, lm_dir):
 
     for frame in tqdm(frames):
         for word in frame:
-            vocab_frames.add('_'.join(word.split()))
+            vocab_frames.add("_".join(word.split()))
 
-    with open(lm_dir + '/words_transcript.txt', 'w') as lexicon_transcript_file:
-        lexicon_transcript_file.write("<UNK> 1" + '\n')
-        lexicon_transcript_file.write("<s> 2" + '\n')
-        lexicon_transcript_file.write("</s> 0" + '\n')
+    with open(lm_dir + "/words_transcript.txt", "w") as lexicon_transcript_file:
+        lexicon_transcript_file.write("<UNK> 1" + "\n")
+        lexicon_transcript_file.write("<s> 2" + "\n")
+        lexicon_transcript_file.write("</s> 0" + "\n")
         id = 3
         for vocab in vocab_transcript:
-            lexicon_transcript_file.write(vocab + ' ' + str(id) + '\n')
+            lexicon_transcript_file.write(vocab + " " + str(id) + "\n")
             id += 1
 
-    with open(lm_dir + '/words_frames.txt', 'w') as lexicon_frames_file:
-        lexicon_frames_file.write("<UNK> 1" + '\n')
-        lexicon_frames_file.write("<s> 2" + '\n')
-        lexicon_frames_file.write("</s> 0" + '\n')
+    with open(lm_dir + "/words_frames.txt", "w") as lexicon_frames_file:
+        lexicon_frames_file.write("<UNK> 1" + "\n")
+        lexicon_frames_file.write("<s> 2" + "\n")
+        lexicon_frames_file.write("</s> 0" + "\n")
         id = 3
         for vocab in vocab_frames:
-            lexicon_frames_file.write(vocab + ' ' + str(id) + '\n')
+            lexicon_frames_file.write(vocab + " " + str(id) + "\n")
             id += 1
 
 
 parser = argparse.ArgumentParser()
-parser.add_argument('corpus_dir')
-parser.add_argument('lm_dir')
+parser.add_argument("corpus_dir")
+parser.add_argument("lm_dir")
 
+
 def main():
     args = parser.parse_args()
 
     generate_lexicon(args.corpus_dir, args.lm_dir)
 
+
 main()
@@ -19,11 +19,11 @@ consisting of words and tokens (i.e., phones) and does the following:
 
 5. Generate L_disambig.pt, in k2 format.
 """
+import argparse
 import math
 from collections import defaultdict
 from pathlib import Path
 from typing import Any, Dict, List, Tuple
-import argparse
 
 import k2
 import torch
@@ -299,8 +299,10 @@ def lexicon_to_fst(
     fsa = k2.Fsa.from_str(arcs, acceptor=False)
     return fsa
 
+
 parser = argparse.ArgumentParser()
-parser.add_argument('lm_dir')
+parser.add_argument("lm_dir")
 
+
 def main():
     args = parser.parse_args()
@@ -312,58 +314,58 @@ def main():
     sil_prob = 0.5
 
     for name, lexicon_filename in zip(names, lexicon_filenames):
         lexicon = read_lexicon(lexicon_filename)
         tokens = get_words(lexicon)
         words = get_words(lexicon)
         new_lexicon = []
         for lexicon_item in lexicon:
             new_lexicon.append((lexicon_item[0], [lexicon_item[0]]))
         lexicon = new_lexicon
 
         lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
 
         for i in range(max_disambig + 1):
             disambig = f"#{i}"
             assert disambig not in tokens
             tokens.append(f"#{i}")
 
         tokens = ["<eps>"] + tokens
-        words = ['eps'] + words + ["#0", "!SIL"]
+        words = ["eps"] + words + ["#0", "!SIL"]
 
         token2id = generate_id_map(tokens)
         word2id = generate_id_map(words)
 
         write_mapping(out_dir / ("tokens_" + name + ".txt"), token2id)
         write_mapping(out_dir / ("words_" + name + ".txt"), word2id)
         write_lexicon(out_dir / ("lexicon_disambig_" + name + ".txt"), lexicon_disambig)
 
         L = lexicon_to_fst(
             lexicon,
             token2id=word2id,
             word2id=word2id,
             sil_token=sil_token,
             sil_prob=sil_prob,
         )
 
         L_disambig = lexicon_to_fst(
             lexicon_disambig,
             token2id=word2id,
             word2id=word2id,
             sil_token=sil_token,
             sil_prob=sil_prob,
             need_self_loops=True,
         )
         torch.save(L.as_dict(), out_dir / ("L_" + name + ".pt"))
         torch.save(L_disambig.as_dict(), out_dir / ("L_disambig_" + name + ".pt"))
 
         if False:
             # Just for debugging, will remove it
             L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
             L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
             L_disambig.labels_sym = L.labels_sym
             L_disambig.aux_labels_sym = L.aux_labels_sym
             L.draw(out_dir / "L.png", title="L")
             L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
 
 
 main()
@@ -20,7 +20,9 @@ import torch
 from transducer.model import Transducer
 
 
-def greedy_search(model: Transducer, encoder_out: torch.Tensor, id2word: dict) -> List[str]:
+def greedy_search(
+    model: Transducer, encoder_out: torch.Tensor, id2word: dict
+) -> List[str]:
     """
     Args:
       model:
@@ -22,12 +22,12 @@ from typing import List, Tuple
 
 import torch
 import torch.nn as nn
-from transducer.slu_datamodule import SluDataModule
 from transducer.beam_search import greedy_search
-from transducer.decoder import Decoder
 from transducer.conformer import Conformer
+from transducer.decoder import Decoder
 from transducer.joiner import Joiner
 from transducer.model import Transducer
+from transducer.slu_datamodule import SluDataModule
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.env import get_env_info
@@ -45,7 +45,7 @@ def get_id2word(params):
     # 0 is blank
     id = 1
     try:
-        with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
+        with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
             for line in lexicon_file:
                 if len(line.strip()) > 0:
                     id2word[id] = line.split()[0]
@@ -82,11 +82,7 @@ def get_parser():
         default="transducer/exp",
         help="Directory from which to load the checkpoints",
     )
-    parser.add_argument(
-        "--lang-dir",
-        type=str,
-        default="data/lm/frames"
-    )
+    parser.add_argument("--lang-dir", type=str, default="data/lm/frames")
 
     return parser
 
@@ -106,9 +102,11 @@ def get_params() -> AttributeDict:
     )
 
     vocab_size = 1
-    with open(params.lang_dir / 'lexicon_disambig.txt') as lexicon_file:
+    with open(params.lang_dir / "lexicon_disambig.txt") as lexicon_file:
         for line in lexicon_file:
-            if len(line.strip()) > 0:# and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
+            if (
+                len(line.strip()) > 0
+            ):  # and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
                 vocab_size += 1
     params.vocab_size = vocab_size
 
@@ -116,10 +114,7 @@ def get_params() -> AttributeDict:
 
 
 def decode_one_batch(
-    params: AttributeDict,
-    model: nn.Module,
-    batch: dict,
-    id2word: dict
+    params: AttributeDict, model: nn.Module, batch: dict, id2word: dict
 ) -> List[List[int]]:
     """Decode one batch and return the result in a list-of-list.
     Each sub list contains the word IDs for an utterance in the batch.
@@ -195,15 +190,18 @@ def decode_dataset(
 
     results = []
     for batch_idx, batch in enumerate(dl):
-        texts = [' '.join(a.supervisions[0].custom["frames"]) for a in batch["supervisions"]["cut"]]
-        texts = ['<s> ' + a.replace('change language', 'change_language') + ' </s>' for a in texts]
+        texts = [
+            " ".join(a.supervisions[0].custom["frames"])
+            for a in batch["supervisions"]["cut"]
+        ]
+        texts = [
+            "<s> " + a.replace("change language", "change_language") + " </s>"
+            for a in texts
+        ]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
         hyps = decode_one_batch(
-            params=params,
-            model=model,
-            batch=batch,
-            id2word=id2word
+            params=params, model=model, batch=batch, id2word=id2word
         )
 
         this_batch = []
@@ -338,7 +336,7 @@ def main():
         model=model,
     )
 
-    test_set_name=str(args.feature_dir).split('/')[-2]
+    test_set_name = str(args.feature_dir).split("/")[-2]
     save_results(exp_dir=params.exp_dir, test_set_name=test_set_name, results=results)
 
     logging.info("Done!")
@@ -282,11 +282,8 @@ class SluDataModule(DataModule):
         )
         return cuts_valid
 
-
     @lru_cache()
     def test_cuts(self) -> List[CutSet]:
         logging.info("About to get test cuts")
-        cuts_test = load_manifest_lazy(
-            self.args.feature_dir / "slu_cuts_test.jsonl.gz"
-        )
+        cuts_test = load_manifest_lazy(self.args.feature_dir / "slu_cuts_test.jsonl.gz")
         return cuts_test
@@ -26,14 +26,15 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
-from slu_datamodule import SluDataModule
 from lhotse.utils import fix_random_seed
+from slu_datamodule import SluDataModule
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
+from transducer.conformer import Conformer
+
 # from torch.utils.tensorboard import SummaryWriter
 from transducer.decoder import Decoder
-from transducer.conformer import Conformer
 from transducer.joiner import Joiner
 from transducer.model import Transducer
 
@@ -49,20 +50,20 @@ def get_word2id(params):
 
     # 0 is blank
     id = 1
-    with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
+    with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
         for line in lexicon_file:
             if len(line.strip()) > 0:
                 word2id[line.split()[0]] = id
                 id += 1
 
     return word2id
 
 
 def get_labels(texts: List[str], word2id) -> k2.RaggedTensor:
     """
     Args:
       texts:
         A list of transcripts.
     Returns:
       Return a ragged tensor containing the corresponding word ID.
     """
@@ -133,11 +134,7 @@ def get_parser():
         help="The seed for random generators intended for reproducibility",
     )
 
-    parser.add_argument(
-        "--lang-dir",
-        type=str,
-        default="data/lm/frames"
-    )
+    parser.add_argument("--lang-dir", type=str, default="data/lm/frames")
 
     return parser
 
@@ -215,9 +212,11 @@ def get_params() -> AttributeDict:
     )
 
     vocab_size = 1
-    with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
+    with open(Path(params.lang_dir) / "lexicon_disambig.txt") as lexicon_file:
         for line in lexicon_file:
-            if len(line.strip()) > 0:# and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
+            if (
+                len(line.strip()) > 0
+            ):  # and '<UNK>' not in line and '<s>' not in line and '</s>' not in line:
                 vocab_size += 1
     params.vocab_size = vocab_size
 
@@ -312,11 +311,7 @@ def save_checkpoint(
 
 
 def compute_loss(
-    params: AttributeDict,
-    model: nn.Module,
-    batch: dict,
-    is_training: bool,
-    word2ids
+    params: AttributeDict, model: nn.Module, batch: dict, is_training: bool, word2ids
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute RNN-T loss given the model and its inputs.
@@ -342,8 +337,14 @@ def compute_loss(
 
     feature_lens = batch["supervisions"]["num_frames"].to(device)
 
-    texts = [' '.join(a.supervisions[0].custom["frames"]) for a in batch["supervisions"]["cut"]]
-    texts = ['<s> ' + a.replace('change language', 'change_language') + ' </s>' for a in texts]
+    texts = [
+        " ".join(a.supervisions[0].custom["frames"])
+        for a in batch["supervisions"]["cut"]
+    ]
+    texts = [
+        "<s> " + a.replace("change language", "change_language") + " </s>"
+        for a in texts
+    ]
     labels = get_labels(texts, word2ids).to(device)
 
     with torch.set_grad_enabled(is_training):
@@ -378,7 +379,7 @@ def compute_validation_loss(
             model=model,
             batch=batch,
             is_training=False,
-            word2ids=word2ids
+            word2ids=word2ids,
         )
         assert loss.requires_grad is False
 
@@ -437,11 +438,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])
 
         loss, loss_info = compute_loss(
-            params=params,
-            model=model,
-            batch=batch,
-            is_training=True,
-            word2ids=word2ids
+            params=params, model=model, batch=batch, is_training=True, word2ids=word2ids
         )
         # summary stats.
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -471,7 +468,7 @@ def train_one_epoch(
                 model=model,
                 valid_dl=valid_dl,
                 world_size=world_size,
-                word2ids=word2ids
+                word2ids=word2ids,
             )
             model.train()
             logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
@@ -593,7 +590,7 @@ def run(rank, world_size, args):
             valid_dl=valid_dl,
             tb_writer=tb_writer,
             world_size=world_size,
-            word2ids=word2ids
+            word2ids=word2ids,
         )
 
         save_checkpoint(