mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 09:04:19 +00:00
update
This commit is contained in:
parent
bea1bd295f
commit
6579800720
@ -72,16 +72,12 @@ class LibriHeavyAsrDataModule:
|
|||||||
self.args = args
|
self.args = args
|
||||||
|
|
||||||
if args.use_context_list:
|
if args.use_context_list:
|
||||||
from dataset2 import PromptASRDataset
|
|
||||||
|
|
||||||
assert args.rare_word_file is not None
|
assert args.rare_word_file is not None
|
||||||
with open(args.rare_word_file, "r") as f:
|
with open(args.rare_word_file, "r") as f:
|
||||||
self.rare_word_list = (
|
self.rare_word_list = (
|
||||||
f.read().lower().split()
|
f.read().lower().split()
|
||||||
) # Use lower-cased for easier style transform
|
) # Use lower-cased for easier style transform
|
||||||
else:
|
else:
|
||||||
from dataset import PromptASRDataset
|
|
||||||
|
|
||||||
self.rare_word_list = None
|
self.rare_word_list = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -20,22 +20,55 @@
|
|||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless7/decode.py \
|
./zipformer_prompt_asr/decode_bert.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
--exp-dir ./zipformer_prompt_asr/exp \
|
||||||
--max-duration 600 \
|
--max-duration 1000 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) modified beam search
|
(2) modified beam search
|
||||||
./pruned_transducer_stateless7/decode.py \
|
./zipformer_prompt_asr/decode_bert.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
--exp-dir ./zipformer_prompt_asr/exp \
|
||||||
--max-duration 600 \
|
--max-duration 1000 \
|
||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
|
(3) Decode LibriSpeech
|
||||||
|
|
||||||
|
./zipformer_prompt_asr/decode_bert.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer_prompt_asr/exp \
|
||||||
|
--max-duration 1000 \
|
||||||
|
--decoding-method modified_beam_search \
|
||||||
|
--use-ls-test-set True \
|
||||||
|
--beam-size 4
|
||||||
|
|
||||||
|
(4) Decode LibriSpeech + biasing list
|
||||||
|
|
||||||
|
biasing_list=100 # could also be 0
|
||||||
|
|
||||||
|
./zipformer_prompt_asr/decode_bert.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./zipformer_prompt_asr/exp \
|
||||||
|
--max-duration 1000 \
|
||||||
|
--decoding-method modified_beam_search \
|
||||||
|
--beam-size 4 \\
|
||||||
|
--use-ls-test-set True \
|
||||||
|
--use-ls-context-list True \
|
||||||
|
--biasing-level utterance \
|
||||||
|
--ls-distractors $biasing_list \
|
||||||
|
--post-normalization True \
|
||||||
|
--text-encoder-type BERT \
|
||||||
|
--max-prompt-lens 1000 \
|
||||||
|
--style-text-transform mixed-punc \
|
||||||
|
--pre-text-transform mixed-punc
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -45,40 +78,34 @@ import math
|
|||||||
import warnings
|
import warnings
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple, Callable
|
from typing import Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers import BertTokenizer, BertModel
|
|
||||||
from asr_datamodule import LibriHeavyAsrDataModule
|
from asr_datamodule import LibriHeavyAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import greedy_search, greedy_search_batch, modified_beam_search
|
||||||
greedy_search,
|
|
||||||
greedy_search_with_context,
|
|
||||||
greedy_search_batch,
|
|
||||||
greedy_search_batch_with_context,
|
|
||||||
modified_beam_search,
|
|
||||||
)
|
|
||||||
from dataset import naive_triplet_text_sampling, random_shuffle_subset
|
from dataset import naive_triplet_text_sampling, random_shuffle_subset
|
||||||
from utils import get_facebook_biasing_list, brian_biasing_list, write_error_stats
|
|
||||||
from ls_text_normalization import word_normalization
|
from ls_text_normalization import word_normalization
|
||||||
from text_normalization import (
|
from text_normalization import (
|
||||||
ref_text_normalization,
|
|
||||||
remove_non_alphabetic,
|
|
||||||
upper_only_alpha,
|
|
||||||
upper_all_char,
|
|
||||||
lower_all_char,
|
lower_all_char,
|
||||||
lower_only_alpha,
|
lower_only_alpha,
|
||||||
|
ref_text_normalization,
|
||||||
|
remove_non_alphabetic,
|
||||||
train_text_normalization,
|
train_text_normalization,
|
||||||
|
upper_all_char,
|
||||||
|
upper_only_alpha,
|
||||||
)
|
)
|
||||||
from train_bert_encoder_with_style import (
|
from train_bert_encoder import (
|
||||||
|
_encode_texts_as_bytes_with_tokenizer,
|
||||||
add_model_arguments,
|
add_model_arguments,
|
||||||
get_params,
|
get_params,
|
||||||
get_tokenizer,
|
get_tokenizer,
|
||||||
get_transducer_model,
|
get_transducer_model,
|
||||||
_encode_texts_as_bytes_with_tokenizer,
|
|
||||||
)
|
)
|
||||||
|
from transformers import BertModel, BertTokenizer
|
||||||
|
from utils import brian_biasing_list, get_facebook_biasing_list, write_error_stats
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
@ -87,15 +114,11 @@ from icefall.checkpoint import (
|
|||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import AttributeDict, setup_logger, store_transcripts, str2bool
|
||||||
AttributeDict,
|
|
||||||
setup_logger,
|
|
||||||
store_transcripts,
|
|
||||||
str2bool,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG_EPS = math.log(1e-10)
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
@ -267,7 +290,7 @@ def get_parser():
|
|||||||
"--use-style-prompt",
|
"--use-style-prompt",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=True,
|
default=True,
|
||||||
help="Use style prompt when evaluation"
|
help="Use style prompt when evaluation",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -276,13 +299,6 @@ def get_parser():
|
|||||||
default=1000,
|
default=1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--use-context-embedding",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="Use context fuser when evaluation"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--post-normalization",
|
"--post-normalization",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -290,47 +306,41 @@ def get_parser():
|
|||||||
help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ",
|
help="Normalized the recognition results by uppercasing and removing non-alphabetic symbols. ",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--long-audio-recog",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--compute-CER",
|
"--compute-CER",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=True,
|
default=False,
|
||||||
help="Reports CER. By default, only reports WER",
|
help="Reports CER. By default, only reports WER",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--style-text-transform",
|
"--style-text-transform",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
|
||||||
default="mixed-punc",
|
default="mixed-punc",
|
||||||
help="The style of style prompt, i.e style_text"
|
help="The style of style prompt, i.e style_text",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pre-text-transform",
|
"--pre-text-transform",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["mixed-punc", "upper-no-punc", "lower-no-punc","lower-punc"],
|
choices=["mixed-punc", "upper-no-punc", "lower-no-punc", "lower-punc"],
|
||||||
default="mixed-punc",
|
default="mixed-punc",
|
||||||
help="The style of content prompt, i.e pre_text"
|
help="The style of content prompt, i.e pre_text",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-ls-test-set",
|
"--use-ls-test-set",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="Use librispeech test set for evaluation."
|
help="Use librispeech test set for evaluation.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-ls-context-list",
|
"--use-ls-context-list",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="If use a fixed context list for LibriSpeech decoding"
|
help="If use a fixed context list for LibriSpeech decoding",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -344,13 +354,14 @@ def get_parser():
|
|||||||
"--ls-distractors",
|
"--ls-distractors",
|
||||||
type=int,
|
type=int,
|
||||||
default=0,
|
default=0,
|
||||||
help="The number of distractors into context list for LibriSpeech decoding"
|
help="The number of distractors into context list for LibriSpeech decoding",
|
||||||
)
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
def _apply_style_transform(text: List[str], transform: str) -> List[str]:
|
||||||
"""Apply transform to a list of text. By default, the text are in
|
"""Apply transform to a list of text. By default, the text are in
|
||||||
ground truth format, i.e mixed-punc.
|
ground truth format, i.e mixed-punc.
|
||||||
@ -378,7 +389,7 @@ def decode_one_batch(
|
|||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
tokenizer,
|
tokenizer: spm.SentencePieceProcessor,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
biasing_dict: dict = None,
|
biasing_dict: dict = None,
|
||||||
word_table: Optional[k2.SymbolTable] = None,
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
@ -401,10 +412,15 @@ def decode_one_batch(
|
|||||||
The neural model.
|
The neural model.
|
||||||
sp:
|
sp:
|
||||||
The BPE model.
|
The BPE model.
|
||||||
|
tokenizer:
|
||||||
|
Tokenizer for the text encoder
|
||||||
batch:
|
batch:
|
||||||
It is the return value from iterating
|
It is the return value from iterating
|
||||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
for the format of the `batch`.
|
for the format of the `batch`.
|
||||||
|
biasing_dict:
|
||||||
|
A dictionary in the form `{cut_id: :w1 w2"}` that contains a list
|
||||||
|
of biasing words (separated with space)
|
||||||
word_table:
|
word_table:
|
||||||
The word symbol table.
|
The word symbol table.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
@ -428,43 +444,48 @@ def decode_one_batch(
|
|||||||
cut_ids = [c.supervisions[0].id for c in cuts]
|
cut_ids = [c.supervisions[0].id for c in cuts]
|
||||||
batch_size = feature.size(0)
|
batch_size = feature.size(0)
|
||||||
|
|
||||||
# get pre_text
|
|
||||||
if "pre_text" in batch["supervisions"] and params.use_pre_text:
|
if "pre_text" in batch["supervisions"] and params.use_pre_text:
|
||||||
pre_texts = batch["supervisions"]["pre_text"]
|
pre_texts = batch["supervisions"]["pre_text"]
|
||||||
pre_texts = [train_text_normalization(t) for t in pre_texts]
|
pre_texts = [train_text_normalization(t) for t in pre_texts]
|
||||||
else:
|
else:
|
||||||
pre_texts = ["" for _ in range(batch_size)]
|
pre_texts = ["" for _ in range(batch_size)]
|
||||||
|
|
||||||
if params.use_ls_context_list:
|
if params.use_ls_context_list and params.use_ls_test_set:
|
||||||
if params.biasing_level == "utterance":
|
if params.biasing_level == "utterance":
|
||||||
pre_texts = [biasing_dict[id] for id in cut_ids]
|
pre_texts = [biasing_dict[id] for id in cut_ids]
|
||||||
elif params.biasing_level == "Chapter":
|
elif params.biasing_level == "Chapter":
|
||||||
chapter_ids = [c.split('-')[1] for c in cut_ids]
|
chapter_ids = [c.split("-")[1] for c in cut_ids]
|
||||||
pre_texts = [biasing_dict[id] for id in chapter_ids]
|
pre_texts = [biasing_dict[id] for id in chapter_ids]
|
||||||
elif params.biasing_level == "Book":
|
elif params.biasing_level == "Book":
|
||||||
chapter_ids = [c.split('-')[1] for c in cut_ids]
|
chapter_ids = [c.split("-")[1] for c in cut_ids]
|
||||||
pre_texts = [biasing_dict[id] for id in chapter_ids]
|
pre_texts = [biasing_dict[id] for id in chapter_ids]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unseen biasing level: {params.biasing_level}")
|
||||||
if params.pre_text_transform == "mixed-punc":
|
if params.pre_text_transform == "mixed-punc":
|
||||||
pre_texts = [t.lower() for t in pre_texts]
|
pre_texts = [t.lower() for t in pre_texts]
|
||||||
|
|
||||||
# get style_text
|
# get style_text
|
||||||
if params.use_style_prompt:
|
if params.use_style_prompt:
|
||||||
fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related."
|
fixed_sentence = "Mixed-case English transcription, with punctuation. Actually, it's fully not related."
|
||||||
style_texts = batch["supervisions"].get("style_text", [fixed_sentence for _ in range(batch_size)])
|
style_texts = batch["supervisions"].get(
|
||||||
|
"style_text", [fixed_sentence for _ in range(batch_size)]
|
||||||
|
)
|
||||||
style_texts = [train_text_normalization(t) for t in style_texts]
|
style_texts = [train_text_normalization(t) for t in style_texts]
|
||||||
else:
|
else:
|
||||||
style_texts = ["" for _ in range(batch_size)] # use empty string
|
style_texts = ["" for _ in range(batch_size)] # use empty string
|
||||||
|
|
||||||
# Get the text embedding input
|
# Get the text embedding
|
||||||
if params.use_pre_text or params.use_style_prompt:
|
if params.use_pre_text or params.use_style_prompt:
|
||||||
|
|
||||||
# apply style transform to the pre_text and style_text
|
# apply style transform to the pre_text and style_text
|
||||||
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
|
pre_texts = _apply_style_transform(pre_texts, params.pre_text_transform)
|
||||||
if not params.use_ls_context_list:
|
if not params.use_ls_context_list:
|
||||||
pre_texts = [t[:params.max_prompt_lens] for t in pre_texts]
|
pre_texts = [t[-params.max_prompt_lens :] for t in pre_texts]
|
||||||
#pre_texts = random_shuffle_subset(pre_texts, p=1.0, p_mask=0.0)
|
|
||||||
if params.use_style_prompt:
|
if params.use_style_prompt:
|
||||||
style_texts = _apply_style_transform(style_texts, params.style_text_transform)
|
style_texts = _apply_style_transform(
|
||||||
|
style_texts, params.style_text_transform
|
||||||
|
)
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
@ -477,12 +498,14 @@ def decode_one_batch(
|
|||||||
device=device,
|
device=device,
|
||||||
no_limit=True,
|
no_limit=True,
|
||||||
)
|
)
|
||||||
logging.info(f"Shape of the encoded prompts: {encoded_inputs['input_ids'].shape}")
|
logging.info(
|
||||||
|
f"Shape of the encoded prompts: {encoded_inputs['input_ids'].shape}"
|
||||||
|
)
|
||||||
|
|
||||||
memory, memory_key_padding_mask = model.encode_text(
|
memory, memory_key_padding_mask = model.encode_text(
|
||||||
encoded_inputs=encoded_inputs,
|
encoded_inputs=encoded_inputs,
|
||||||
style_lens=style_lens,
|
style_lens=style_lens,
|
||||||
) # (T,B,C)
|
) # (T,B,C)
|
||||||
else:
|
else:
|
||||||
memory = None
|
memory = None
|
||||||
memory_key_padding_mask = None
|
memory_key_padding_mask = None
|
||||||
@ -506,26 +529,12 @@ def decode_one_batch(
|
|||||||
|
|
||||||
hyps = []
|
hyps = []
|
||||||
|
|
||||||
if (
|
if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
params.decoding_method == "greedy_search"
|
hyp_tokens = greedy_search_batch(
|
||||||
and params.max_sym_per_frame == 1
|
model=model,
|
||||||
):
|
encoder_out=encoder_out,
|
||||||
if memory is None or not params.use_context_embedding:
|
encoder_out_lens=encoder_out_lens,
|
||||||
hyp_tokens = greedy_search_batch(
|
)
|
||||||
model=model,
|
|
||||||
encoder_out=encoder_out,
|
|
||||||
encoder_out_lens=encoder_out_lens,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
memory = memory.permute(1,0,2) # (T,N,C) -> (N,T,C)
|
|
||||||
context = model.context_fuser(memory, padding_mask=memory_key_padding_mask) # (N,C)
|
|
||||||
context = model.joiner.context_proj(context) # (N,C)
|
|
||||||
hyp_tokens = greedy_search_batch_with_context(
|
|
||||||
model=model,
|
|
||||||
encoder_out=encoder_out,
|
|
||||||
encoder_out_lens=encoder_out_lens,
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
elif params.decoding_method == "modified_beam_search":
|
elif params.decoding_method == "modified_beam_search":
|
||||||
@ -545,25 +554,10 @@ def decode_one_batch(
|
|||||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
if memory is None or not params.use_context_embedding:
|
hyp = greedy_search(
|
||||||
hyp = greedy_search(
|
|
||||||
model=model,
|
|
||||||
encoder_out=encoder_out_i,
|
|
||||||
max_sym_per_frame=params.max_sym_per_frame,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cur_context = context[i:i+1, :]
|
|
||||||
hyp = greedy_search_with_context(
|
|
||||||
model=model,
|
|
||||||
encoder_out=encoder_out_i,
|
|
||||||
context=cur_context,
|
|
||||||
max_sym_per_frame=params.max_sym_per_frame,
|
|
||||||
)
|
|
||||||
elif params.decoding_method == "beam_search":
|
|
||||||
hyp = beam_search(
|
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out_i,
|
encoder_out=encoder_out_i,
|
||||||
beam=params.beam_size,
|
max_sym_per_frame=params.max_sym_per_frame,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -582,7 +576,7 @@ def decode_dataset(
|
|||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
tokenizer,
|
tokenizer: spm.SentencePieceProcessor,
|
||||||
biasing_dict: Dict = None,
|
biasing_dict: Dict = None,
|
||||||
word_table: Optional[k2.SymbolTable] = None,
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
@ -598,6 +592,11 @@ def decode_dataset(
|
|||||||
The neural model.
|
The neural model.
|
||||||
sp:
|
sp:
|
||||||
The BPE model.
|
The BPE model.
|
||||||
|
tokenizer:
|
||||||
|
Tokenizer for the text encoder
|
||||||
|
biasing_dict:
|
||||||
|
A dictionary in the form `{cut_id: :w1 w2"}` that contains a list
|
||||||
|
of biasing words (separated with space)
|
||||||
word_table:
|
word_table:
|
||||||
The word symbol table.
|
The word symbol table.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
@ -627,7 +626,9 @@ def decode_dataset(
|
|||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
texts = batch["supervisions"]["text"] # By default, this should be in mixed-punc format
|
texts = batch["supervisions"][
|
||||||
|
"text"
|
||||||
|
] # By default, this should be in mixed-punc format
|
||||||
|
|
||||||
# the style of ref_text should match style_text
|
# the style of ref_text should match style_text
|
||||||
texts = _apply_style_transform(texts, params.style_text_transform)
|
texts = _apply_style_transform(texts, params.style_text_transform)
|
||||||
@ -637,9 +638,13 @@ def decode_dataset(
|
|||||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
if not params.use_ls_test_set:
|
if not params.use_ls_test_set:
|
||||||
try:
|
try:
|
||||||
book_names = [cut.text_path.split('/')[-2] for cut in batch["supervisions"]["cut"]]
|
book_names = [
|
||||||
except:
|
cut.text_path.split("/")[-2] for cut in batch["supervisions"]["cut"]
|
||||||
book_names = [cut.id.split('/')[0] for cut in batch["supervisions"]["cut"]]
|
]
|
||||||
|
except AttributeError:
|
||||||
|
book_names = [
|
||||||
|
cut.id.split("/")[0] for cut in batch["supervisions"]["cut"]
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
book_names = ["" for _ in cut_ids]
|
book_names = ["" for _ in cut_ids]
|
||||||
|
|
||||||
@ -657,7 +662,9 @@ def decode_dataset(
|
|||||||
for name, hyps in hyps_dict.items():
|
for name, hyps in hyps_dict.items():
|
||||||
this_batch = []
|
this_batch = []
|
||||||
assert len(hyps) == len(texts)
|
assert len(hyps) == len(texts)
|
||||||
for cut_id, book_name, hyp_words, ref_text in zip(cut_ids, book_names, hyps, texts):
|
for cut_id, book_name, hyp_words, ref_text in zip(
|
||||||
|
cut_ids, book_names, hyps, texts
|
||||||
|
):
|
||||||
ref_text = ref_text_normalization(
|
ref_text = ref_text_normalization(
|
||||||
ref_text
|
ref_text
|
||||||
) # remove full-width symbols & some book marks
|
) # remove full-width symbols & some book marks
|
||||||
@ -672,9 +679,7 @@ def decode_dataset(
|
|||||||
if batch_idx % log_interval == 0:
|
if batch_idx % log_interval == 0:
|
||||||
batch_str = f"{batch_idx}/{num_batches}"
|
batch_str = f"{batch_idx}/{num_batches}"
|
||||||
|
|
||||||
logging.info(
|
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
|
||||||
)
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@ -705,7 +710,9 @@ def save_results(
|
|||||||
|
|
||||||
if params.compute_CER:
|
if params.compute_CER:
|
||||||
# Write CER statistics
|
# Write CER statistics
|
||||||
recog_path = params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
|
recog_path = (
|
||||||
|
params.res_dir / f"recogs-{test_set_name}-char-{params.suffix}.txt"
|
||||||
|
)
|
||||||
store_transcripts(filename=recog_path, texts=results, char_level=True)
|
store_transcripts(filename=recog_path, texts=results, char_level=True)
|
||||||
errs_filename = (
|
errs_filename = (
|
||||||
params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
|
params.res_dir / f"errs-CER-{test_set_name}-{params.suffix}.txt"
|
||||||
@ -723,9 +730,7 @@ def save_results(
|
|||||||
logging.info("Wrote detailed CER stats to {}".format(errs_filename))
|
logging.info("Wrote detailed CER stats to {}".format(errs_filename))
|
||||||
|
|
||||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
errs_info = (
|
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
||||||
params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
|
|
||||||
)
|
|
||||||
with open(errs_info, "w") as f:
|
with open(errs_info, "w") as f:
|
||||||
print("settings\tWER", file=f)
|
print("settings\tWER", file=f)
|
||||||
for key, val in test_set_wers:
|
for key, val in test_set_wers:
|
||||||
@ -740,9 +745,7 @@ def save_results(
|
|||||||
|
|
||||||
if params.compute_CER:
|
if params.compute_CER:
|
||||||
test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
|
test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
|
||||||
errs_info = (
|
errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
||||||
params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
|
||||||
)
|
|
||||||
with open(errs_info, "w") as f:
|
with open(errs_info, "w") as f:
|
||||||
print("settings\tCER", file=f)
|
print("settings\tCER", file=f)
|
||||||
for key, val in test_set_cers:
|
for key, val in test_set_cers:
|
||||||
@ -771,10 +774,7 @@ def main():
|
|||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.long_audio_recog:
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
params.res_dir = params.exp_dir / (params.decoding_method + "long_audio")
|
|
||||||
else:
|
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
|
||||||
|
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||||
@ -792,22 +792,19 @@ def main():
|
|||||||
params.suffix += f"-left-context-{params.left_context_frames}"
|
params.suffix += f"-left-context-{params.left_context_frames}"
|
||||||
|
|
||||||
if "beam_search" in params.decoding_method:
|
if "beam_search" in params.decoding_method:
|
||||||
params.suffix += (
|
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
|
|
||||||
if params.use_pre_text:
|
if params.use_pre_text:
|
||||||
params.suffix += f"-pre-text-{params.pre_text_transform}-len-{params.max_prompt_lens}"
|
params.suffix += (
|
||||||
|
f"-pre-text-{params.pre_text_transform}-len-{params.max_prompt_lens}"
|
||||||
|
)
|
||||||
|
|
||||||
if params.use_style_prompt:
|
if params.use_style_prompt:
|
||||||
params.suffix += f"-style-prompt-{params.style_text_transform}"
|
params.suffix += f"-style-prompt-{params.style_text_transform}"
|
||||||
|
|
||||||
if params.use_context_embedding:
|
|
||||||
params.suffix += f"-use-context-fuser"
|
|
||||||
|
|
||||||
if params.use_ls_context_list:
|
if params.use_ls_context_list:
|
||||||
params.suffix += f"-use-{params.biasing_level}-level-ls-context-list"
|
params.suffix += f"-use-{params.biasing_level}-level-ls-context-list"
|
||||||
if params.biasing_level == "utterance" and params.ls_distractors:
|
if params.biasing_level == "utterance" and params.ls_distractors:
|
||||||
@ -841,9 +838,9 @@ def main():
|
|||||||
|
|
||||||
if not params.use_averaged_model:
|
if not params.use_averaged_model:
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
filenames = find_checkpoints(
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
params.exp_dir, iteration=-params.iter
|
: params.avg
|
||||||
)[: params.avg]
|
]
|
||||||
if len(filenames) == 0:
|
if len(filenames) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No checkpoints found for"
|
f"No checkpoints found for"
|
||||||
@ -870,9 +867,9 @@ def main():
|
|||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
else:
|
else:
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
filenames = find_checkpoints(
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
params.exp_dir, iteration=-params.iter
|
: params.avg + 1
|
||||||
)[: params.avg + 1]
|
]
|
||||||
if len(filenames) == 0:
|
if len(filenames) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No checkpoints found for"
|
f"No checkpoints found for"
|
||||||
@ -935,18 +932,15 @@ def main():
|
|||||||
test_other_cuts = libriheavy.test_other_cuts()
|
test_other_cuts = libriheavy.test_other_cuts()
|
||||||
ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
|
ls_test_clean_cuts = libriheavy.librispeech_test_clean_cuts()
|
||||||
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
|
ls_test_other_cuts = libriheavy.librispeech_test_other_cuts()
|
||||||
long_audio_cuts = libriheavy.long_audio_cuts()
|
|
||||||
|
|
||||||
npr1_dev_cuts = libriheavy.npr1_dev_cuts()
|
test_clean_dl = libriheavy.valid_dataloaders(
|
||||||
npr1_test_cuts = libriheavy.npr1_test_cuts()
|
test_clean_cuts, text_sampling_func=naive_triplet_text_sampling
|
||||||
|
)
|
||||||
test_clean_dl = libriheavy.valid_dataloaders(test_clean_cuts, text_sampling_func=naive_triplet_text_sampling)
|
test_other_dl = libriheavy.valid_dataloaders(
|
||||||
test_other_dl = libriheavy.valid_dataloaders(test_other_cuts, text_sampling_func=naive_triplet_text_sampling)
|
test_other_cuts, text_sampling_func=naive_triplet_text_sampling
|
||||||
|
)
|
||||||
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
|
ls_test_clean_dl = libriheavy.test_dataloaders(ls_test_clean_cuts)
|
||||||
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
|
ls_test_other_dl = libriheavy.test_dataloaders(ls_test_other_cuts)
|
||||||
long_audio_dl = libriheavy.valid_dataloaders(long_audio_cuts, text_sampling_func=naive_triplet_text_sampling)
|
|
||||||
npr1_dev_dl = libriheavy.valid_dataloaders(npr1_dev_cuts, text_sampling_func=naive_triplet_text_sampling)
|
|
||||||
npr1_test_dl = libriheavy.valid_dataloaders(npr1_test_cuts, text_sampling_func=naive_triplet_text_sampling)
|
|
||||||
|
|
||||||
if params.use_ls_test_set:
|
if params.use_ls_test_set:
|
||||||
test_sets = ["ls-test-clean", "ls-test-other"]
|
test_sets = ["ls-test-clean", "ls-test-other"]
|
||||||
@ -955,17 +949,19 @@ def main():
|
|||||||
test_sets = ["test-clean", "test-other"]
|
test_sets = ["test-clean", "test-other"]
|
||||||
test_dl = [test_clean_dl, test_other_dl]
|
test_dl = [test_clean_dl, test_other_dl]
|
||||||
|
|
||||||
if params.long_audio_recog:
|
|
||||||
test_sets = ["long-audio"]
|
|
||||||
test_dl = [long_audio_dl]
|
|
||||||
|
|
||||||
for test_set, test_dl in zip(test_sets, test_dl):
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
if test_set == "ls-test-clean":
|
biasing_dict = None
|
||||||
biasing_dict = get_facebook_biasing_list("test-clean", use_distractors=params.ls_distractors)
|
if params.use_ls_context_list:
|
||||||
elif test_set == "ls-test-other":
|
if test_set == "ls-test-clean":
|
||||||
biasing_dict = get_facebook_biasing_list("test-other", use_distractors=params.ls_distractors)
|
biasing_dict = get_facebook_biasing_list(
|
||||||
else:
|
test_set="test-clean",
|
||||||
biasing_dict = None
|
num_distractors=params.ls_distractors,
|
||||||
|
)
|
||||||
|
elif test_set == "ls-test-other":
|
||||||
|
biasing_dict = get_facebook_biasing_list(
|
||||||
|
test_set="test-other",
|
||||||
|
num_distractors=params.ls_distractors,
|
||||||
|
)
|
||||||
|
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
dl=test_dl,
|
dl=test_dl,
|
||||||
@ -994,7 +990,9 @@ def main():
|
|||||||
for item in results_dict[k]:
|
for item in results_dict[k]:
|
||||||
id, ref, hyp = item
|
id, ref, hyp = item
|
||||||
if params.use_ls_test_set:
|
if params.use_ls_test_set:
|
||||||
hyp = " ".join(hyp).replace("-", " ").split() # handle the hypens
|
hyp = (
|
||||||
|
" ".join(hyp).replace("-", " ").split()
|
||||||
|
) # handle the hypens
|
||||||
hyp = upper_only_alpha(" ".join(hyp)).split()
|
hyp = upper_only_alpha(" ".join(hyp)).split()
|
||||||
hyp = [word_normalization(w.upper()) for w in hyp]
|
hyp = [word_normalization(w.upper()) for w in hyp]
|
||||||
hyp = " ".join(hyp).split()
|
hyp = " ".join(hyp).split()
|
||||||
@ -1003,7 +1001,7 @@ def main():
|
|||||||
else:
|
else:
|
||||||
hyp = upper_only_alpha(" ".join(hyp)).split()
|
hyp = upper_only_alpha(" ".join(hyp)).split()
|
||||||
ref = upper_only_alpha(" ".join(ref)).split()
|
ref = upper_only_alpha(" ".join(ref)).split()
|
||||||
new_ans.append((id,ref,hyp))
|
new_ans.append((id, ref, hyp))
|
||||||
new_res[k] = new_ans
|
new_res[k] = new_ans
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
|
@ -22,24 +22,43 @@ Usage:
|
|||||||
|
|
||||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
|
|
||||||
./pruned_transducer_stateless7/train.py \
|
|
||||||
--world-size 4 \
|
|
||||||
--num-epochs 30 \
|
|
||||||
--start-epoch 1 \
|
|
||||||
--exp-dir pruned_transducer_stateless7/exp \
|
|
||||||
--full-libri 1 \
|
|
||||||
--max-duration 300
|
|
||||||
|
|
||||||
# For mix precision training:
|
# For mix precision training:
|
||||||
|
|
||||||
./pruned_transducer_stateless7/train.py \
|
(1) Non-streaming model, without context list
|
||||||
|
|
||||||
|
./zipformer_prompt_asr/train.py \
|
||||||
--world-size 4 \
|
--world-size 4 \
|
||||||
--num-epochs 30 \
|
--num-epochs 30 \
|
||||||
--start-epoch 1 \
|
--start-epoch 1 \
|
||||||
--use-fp16 1 \
|
--use-fp16 1 \
|
||||||
--exp-dir pruned_transducer_stateless7/exp \
|
--subset medium \
|
||||||
--full-libri 1 \
|
--causal False \
|
||||||
--max-duration 550
|
--exp-dir zipformer_prompt_asr/exp \
|
||||||
|
--max-duration 1000 \
|
||||||
|
--memory-layer 0 \
|
||||||
|
--memory-dim 768 \
|
||||||
|
--text-encoder-type BERT \
|
||||||
|
--use-style-prompt True \
|
||||||
|
--use-context-list False
|
||||||
|
|
||||||
|
(2) Non-streaming model, with context list
|
||||||
|
|
||||||
|
./zipformer_prompt_asr/train.py \
|
||||||
|
--world-size 4 \
|
||||||
|
--num-epochs 30 \
|
||||||
|
--start-epoch 1 \
|
||||||
|
--use-fp16 1 \
|
||||||
|
--subset medium \
|
||||||
|
--causal False \
|
||||||
|
--exp-dir zipformer_prompt_asr/exp \
|
||||||
|
--max-duration 1000 \
|
||||||
|
--memory-layer 0 \
|
||||||
|
--memory-dim 768 \
|
||||||
|
--text-encoder-type BERT \
|
||||||
|
--use-style-prompt True \
|
||||||
|
--use-context-list True \
|
||||||
|
--rare-word-file data/context_biasing/small_rare_words_topk_10000.txt
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -61,30 +80,32 @@ import torch
|
|||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriHeavyAsrDataModule
|
from asr_datamodule import LibriHeavyAsrDataModule
|
||||||
from dataset2 import (
|
from dataset import (
|
||||||
triplet_text_sampling,
|
|
||||||
triplet_text_sampling_with_context_list,
|
|
||||||
naive_triplet_text_sampling,
|
naive_triplet_text_sampling,
|
||||||
random_shuffle_subset,
|
random_shuffle_subset,
|
||||||
joint_triplet_text_sampling,
|
triplet_text_sampling,
|
||||||
triplet_style_text_sampling,
|
triplet_text_sampling_with_context_list,
|
||||||
)
|
)
|
||||||
from dataset import multi_ref_text_triplet_text_sampling
|
|
||||||
|
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from lhotse.dataset.sampling.base import CutSampler
|
from lhotse.dataset.sampling.base import CutSampler
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from model_with_BERT_with_style import PromptedTransducer
|
from model_with_BERT import PromptedTransducer
|
||||||
from optim import Eden, ScaledAdam
|
from optim import Eden, ScaledAdam
|
||||||
from scaling import ScheduledFloat, Balancer, BiasNorm, Dropout3, ScaleGrad, SwooshR
|
from scaling import Balancer, BiasNorm, Dropout3, ScaleGrad, ScheduledFloat, SwooshR
|
||||||
from subsampling import Conv2dSubsampling
|
from subsampling import Conv2dSubsampling
|
||||||
|
from text_normalization import (
|
||||||
|
lower_all_char,
|
||||||
|
lower_only_alpha,
|
||||||
|
train_text_normalization,
|
||||||
|
upper_all_char,
|
||||||
|
upper_only_alpha,
|
||||||
|
)
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.cuda.amp import GradScaler
|
from torch.cuda.amp import GradScaler
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from text_normalization import train_text_normalization, upper_only_alpha, lower_only_alpha, upper_all_char, lower_all_char
|
|
||||||
from zipformer import Zipformer2
|
from zipformer import Zipformer2
|
||||||
|
|
||||||
from icefall import diagnostics
|
from icefall import diagnostics
|
||||||
@ -105,20 +126,20 @@ from icefall.utils import (
|
|||||||
str2bool,
|
str2bool,
|
||||||
)
|
)
|
||||||
|
|
||||||
LRSchedulerType = Union[
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
|
||||||
]
|
|
||||||
|
|
||||||
style_transforms = [
|
style_transforms = [
|
||||||
lambda x: x, # return it self
|
lambda x: x, # return it self
|
||||||
upper_only_alpha,
|
upper_only_alpha,
|
||||||
lower_only_alpha,
|
lower_only_alpha,
|
||||||
lower_all_char,
|
lower_all_char,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def random_sampling(texts: List[str]) -> str:
|
def random_sampling(texts: List[str]) -> str:
|
||||||
return random.choice(texts)
|
return random.choice(texts)
|
||||||
|
|
||||||
|
|
||||||
def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
|
def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
|
||||||
# Randomly choose from the ground truth (mixed-cased trans) and the recog_text
|
# Randomly choose from the ground truth (mixed-cased trans) and the recog_text
|
||||||
i = random.randint(0, 1)
|
i = random.randint(0, 1)
|
||||||
@ -130,6 +151,7 @@ def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
|
|||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def get_first(texts: List[str], pre_texts: List[str]) -> str:
|
def get_first(texts: List[str], pre_texts: List[str]) -> str:
|
||||||
out = {
|
out = {
|
||||||
"text": texts[0],
|
"text": texts[0],
|
||||||
@ -139,6 +161,7 @@ def get_first(texts: List[str], pre_texts: List[str]) -> str:
|
|||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def get_upper_only_alpha(texts: List[str], pre_texts: List[str]) -> str:
|
def get_upper_only_alpha(texts: List[str], pre_texts: List[str]) -> str:
|
||||||
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
|
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
|
||||||
out = {
|
out = {
|
||||||
@ -149,6 +172,7 @@ def get_upper_only_alpha(texts: List[str], pre_texts: List[str]) -> str:
|
|||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def get_adjusted_batch_count(params: AttributeDict) -> float:
|
def get_adjusted_batch_count(params: AttributeDict) -> float:
|
||||||
# returns the number of batches we would have used so far if we had used the reference
|
# returns the number of batches we would have used so far if we had used the reference
|
||||||
# duration. This is for purposes of set_batch_count().
|
# duration. This is for purposes of set_batch_count().
|
||||||
@ -210,14 +234,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
"--memory-dropout-rate",
|
"--memory-dropout-rate",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.05,
|
default=0.05,
|
||||||
help="By which probability, dropout the memory when doing cross-attention."
|
help="By which probability, dropout the memory when doing cross-attention.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--memory-layer",
|
"--memory-layer",
|
||||||
type=int,
|
type=int,
|
||||||
default=0,
|
default=0,
|
||||||
help="Start doing cross-attention from which layer. Zero-indexed"
|
help="Start doing cross-attention from which layer. Zero-indexed",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -285,8 +309,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
"--context-size",
|
"--context-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=2,
|
||||||
help="The context size in the decoder. 1 means bigram; "
|
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||||
"2 means tri-gram",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -317,7 +340,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
"--text-encoder-type",
|
"--text-encoder-type",
|
||||||
type=str,
|
type=str,
|
||||||
default="BERT",
|
default="BERT",
|
||||||
choices=["BERT","DistilBERT"],
|
choices=["BERT", "DistilBERT"],
|
||||||
help="Type of the text encoder",
|
help="Type of the text encoder",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -325,7 +348,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
"--text-encoder-adapter",
|
"--text-encoder-adapter",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="An adapter for pre-trained BERT"
|
help="An adapter for pre-trained BERT",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -459,8 +482,7 @@ def get_parser():
|
|||||||
"--am-scale",
|
"--am-scale",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.0,
|
default=0.0,
|
||||||
help="The scale to smooth the loss with am (output of encoder network)"
|
help="The scale to smooth the loss with am (output of encoder network)" "part.",
|
||||||
"part.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -674,25 +696,25 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
|
|||||||
class TextEmbedding(nn.Module):
|
class TextEmbedding(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_embeddings: int=256,
|
num_embeddings: int = 256,
|
||||||
embedding_dim: int=256,
|
embedding_dim: int = 256,
|
||||||
kernel_size: int=3,
|
kernel_size: int = 3,
|
||||||
layer1_channels: int = 256,
|
layer1_channels: int = 256,
|
||||||
layer2_channels: int = 256,
|
layer2_channels: int = 256,
|
||||||
bias: bool=True,
|
bias: bool = True,
|
||||||
dropout: float = 0.1
|
dropout: float = 0.1,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed = nn.Embedding(
|
self.embed = nn.Embedding(
|
||||||
num_embeddings=num_embeddings, # we encode the text as UTF-8 bytes
|
num_embeddings=num_embeddings, # we encode the text as UTF-8 bytes
|
||||||
embedding_dim=embedding_dim, #
|
embedding_dim=embedding_dim, #
|
||||||
)
|
)
|
||||||
|
|
||||||
assert embedding_dim == layer1_channels # for depth wise convolution
|
assert embedding_dim == layer1_channels # for depth wise convolution
|
||||||
self.conv = nn.Sequential(
|
self.conv = nn.Sequential(
|
||||||
nn.Conv1d(
|
nn.Conv1d(
|
||||||
embedding_dim,
|
embedding_dim,
|
||||||
layer1_channels, # depthwise convolution
|
layer1_channels, # depthwise convolution
|
||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=(kernel_size - 1) // 2,
|
padding=(kernel_size - 1) // 2,
|
||||||
@ -705,7 +727,7 @@ class TextEmbedding(nn.Module):
|
|||||||
nn.Conv1d(
|
nn.Conv1d(
|
||||||
layer1_channels,
|
layer1_channels,
|
||||||
layer2_channels,
|
layer2_channels,
|
||||||
kernel_size=1, # pointwise convolution
|
kernel_size=1, # pointwise convolution
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0,
|
padding=0,
|
||||||
bias=True,
|
bias=True,
|
||||||
@ -725,13 +747,13 @@ class TextEmbedding(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
The embeddings of text (T,N,C)
|
The embeddings of text (T,N,C)
|
||||||
"""
|
"""
|
||||||
text = self.embed(text) # (T,N,C)
|
text = self.embed(text) # (T,N,C)
|
||||||
|
|
||||||
#src = text
|
# src = text
|
||||||
text = text.permute(1,2,0) # (T,N,C) -> (N,C,T)
|
text = text.permute(1, 2, 0) # (T,N,C) -> (N,C,T)
|
||||||
text = self.conv(text)
|
text = self.conv(text)
|
||||||
text = text.permute(2,0,1) # (N,C,T) -> (T,N,C)
|
text = text.permute(2, 0, 1) # (N,C,T) -> (T,N,C)
|
||||||
#src = src + text
|
# src = src + text
|
||||||
|
|
||||||
text = self.out_norm(text)
|
text = self.out_norm(text)
|
||||||
text = self.dropout(text)
|
text = self.dropout(text)
|
||||||
@ -743,11 +765,13 @@ def get_text_encoder(params: AttributeDict) -> nn.Module:
|
|||||||
# Return a text encoder
|
# Return a text encoder
|
||||||
if params.text_encoder_type == "BERT":
|
if params.text_encoder_type == "BERT":
|
||||||
from transformers import BertModel
|
from transformers import BertModel
|
||||||
|
|
||||||
# This is a BERT-base-cased
|
# This is a BERT-base-cased
|
||||||
logging.info("Loading pre-trained BERT-base-cased as text encoder")
|
logging.info("Loading pre-trained BERT-base-cased as text encoder")
|
||||||
model = BertModel.from_pretrained("bert-base-cased")
|
model = BertModel.from_pretrained("bert-base-cased")
|
||||||
elif params.text_encoder_type == "DistilBERT":
|
elif params.text_encoder_type == "DistilBERT":
|
||||||
from transformers import DistilBertModel
|
from transformers import DistilBertModel
|
||||||
|
|
||||||
# This is a DistilBERT-base-cased
|
# This is a DistilBERT-base-cased
|
||||||
logging.info("Loading pre-trained DistilBERT-base-cased as text encoder")
|
logging.info("Loading pre-trained DistilBERT-base-cased as text encoder")
|
||||||
model = DistilBertModel.from_pretrained("distilbert-base-cased")
|
model = DistilBertModel.from_pretrained("distilbert-base-cased")
|
||||||
@ -756,20 +780,24 @@ def get_text_encoder(params: AttributeDict) -> nn.Module:
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def get_tokenizer(params: AttributeDict):
|
def get_tokenizer(params: AttributeDict):
|
||||||
|
|
||||||
if params.text_encoder_type == "BERT":
|
if params.text_encoder_type == "BERT":
|
||||||
from transformers import BertTokenizer
|
from transformers import BertTokenizer
|
||||||
|
|
||||||
# This is a BERT-base-cased
|
# This is a BERT-base-cased
|
||||||
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
|
||||||
elif params.text_encoder_type == "DistilBERT":
|
elif params.text_encoder_type == "DistilBERT":
|
||||||
from transformers import DistilBertTokenizer
|
from transformers import DistilBertTokenizer
|
||||||
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
|
|
||||||
|
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased")
|
||||||
else:
|
else:
|
||||||
raise ValueError()
|
raise ValueError()
|
||||||
|
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = Zipformer2(
|
encoder = Zipformer2(
|
||||||
output_downsampling_factor=2,
|
output_downsampling_factor=2,
|
||||||
@ -789,7 +817,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
causal=params.causal,
|
causal=params.causal,
|
||||||
chunk_size=_to_int_tuple(params.chunk_size),
|
chunk_size=_to_int_tuple(params.chunk_size),
|
||||||
left_context_frames=_to_int_tuple(params.left_context_frames),
|
left_context_frames=_to_int_tuple(params.left_context_frames),
|
||||||
memory_dim=768, # This is fixed as the BERT base model is 768-D
|
memory_dim=768, # This is fixed as the BERT base model is 768-D
|
||||||
memory_layer=params.memory_layer,
|
memory_layer=params.memory_layer,
|
||||||
memory_dropout_rate=params.memory_dropout_rate,
|
memory_dropout_rate=params.memory_dropout_rate,
|
||||||
)
|
)
|
||||||
@ -812,7 +840,9 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
|||||||
decoder_dim=params.decoder_dim,
|
decoder_dim=params.decoder_dim,
|
||||||
joiner_dim=params.joiner_dim,
|
joiner_dim=params.joiner_dim,
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
context_dim=4 * 768 if params.context_injection else -1, # the output dim of text encoder
|
context_dim=4 * 768
|
||||||
|
if params.context_injection
|
||||||
|
else -1, # the output dim of text encoder
|
||||||
context_injection=params.context_injection,
|
context_injection=params.context_injection,
|
||||||
)
|
)
|
||||||
return joiner
|
return joiner
|
||||||
@ -821,24 +851,12 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
|||||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder_embed = get_encoder_embed(params)
|
encoder_embed = get_encoder_embed(params)
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
text_encoder = get_text_encoder(params) # This should be a cased BERT base model
|
text_encoder = get_text_encoder(params) # This should be a cased BERT base model
|
||||||
num_param = sum([p.numel() for p in text_encoder.parameters()])
|
num_param = sum([p.numel() for p in text_encoder.parameters()])
|
||||||
logging.info(f"Num params in text encoder: {num_param}")
|
logging.info(f"Num params in text encoder: {num_param}")
|
||||||
decoder = get_decoder_model(params)
|
decoder = get_decoder_model(params)
|
||||||
joiner = get_joiner_model(params)
|
joiner = get_joiner_model(params)
|
||||||
|
|
||||||
if params.context_injection:
|
|
||||||
from context_fuser import ContextFuser, SelfAttContextFuser
|
|
||||||
context_fuser = SelfAttContextFuser(
|
|
||||||
embed_dim=768,
|
|
||||||
nhead=4,
|
|
||||||
context_dropout_rate=params.context_dropout_rate,
|
|
||||||
)
|
|
||||||
logging.info(f"Using context injection!")
|
|
||||||
logging.info(context_fuser)
|
|
||||||
else:
|
|
||||||
context_fuser = None
|
|
||||||
|
|
||||||
model = PromptedTransducer(
|
model = PromptedTransducer(
|
||||||
encoder_embed=encoder_embed,
|
encoder_embed=encoder_embed,
|
||||||
encoder=encoder,
|
encoder=encoder,
|
||||||
@ -851,12 +869,9 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
text_encoder_type=params.text_encoder_type,
|
text_encoder_type=params.text_encoder_type,
|
||||||
text_encoder_adapter=params.text_encoder_adapter,
|
text_encoder_adapter=params.text_encoder_adapter,
|
||||||
context_fuser=context_fuser,
|
context_fuser=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.text_encoder_adapter:
|
|
||||||
logging.info(f"Using adapter for BERT encoder")
|
|
||||||
logging.info(f"{model.text_encoder_adapter}")
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@ -978,13 +993,14 @@ def save_checkpoint(
|
|||||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||||
copyfile(src=filename, dst=best_valid_filename)
|
copyfile(src=filename, dst=best_valid_filename)
|
||||||
|
|
||||||
|
|
||||||
def _encode_texts_as_bytes_with_tokenizer(
|
def _encode_texts_as_bytes_with_tokenizer(
|
||||||
pre_texts: List[str],
|
pre_texts: List[str],
|
||||||
style_texts: List[str],
|
style_texts: List[str],
|
||||||
tokenizer,
|
tokenizer,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
max_len: int=500,
|
max_len: int = 500,
|
||||||
no_limit: bool=False
|
no_limit: bool = False,
|
||||||
) -> Tuple[Dict, Tensor]:
|
) -> Tuple[Dict, Tensor]:
|
||||||
"""
|
"""
|
||||||
Encode texts as bytes and then integer tensors.
|
Encode texts as bytes and then integer tensors.
|
||||||
@ -997,12 +1013,14 @@ def _encode_texts_as_bytes_with_tokenizer(
|
|||||||
allowed_lens = [5000 - len(s) for s in style_texts]
|
allowed_lens = [5000 - len(s) for s in style_texts]
|
||||||
else:
|
else:
|
||||||
allowed_lens = [1000 - len(s) for s in style_texts]
|
allowed_lens = [1000 - len(s) for s in style_texts]
|
||||||
truncated_pre_texts = [pre_texts[i][-allowed_lens[i]:] for i in range(batch_size)]
|
truncated_pre_texts = [pre_texts[i][-allowed_lens[i] :] for i in range(batch_size)]
|
||||||
combined_text = [style_texts[i] + ' [SEP] ' + truncated_pre_texts[i] for i in range(batch_size)]
|
combined_text = [
|
||||||
|
style_texts[i] + " [SEP] " + truncated_pre_texts[i] for i in range(batch_size)
|
||||||
|
]
|
||||||
|
|
||||||
encoded_style_texts = tokenizer(
|
encoded_style_texts = tokenizer(
|
||||||
style_texts,
|
style_texts,
|
||||||
return_tensors='pt',
|
return_tensors="pt",
|
||||||
padding=True,
|
padding=True,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
return_length=True,
|
return_length=True,
|
||||||
@ -1013,7 +1031,7 @@ def _encode_texts_as_bytes_with_tokenizer(
|
|||||||
# Use tokenizer to prepare input for text encoder
|
# Use tokenizer to prepare input for text encoder
|
||||||
encoded_inputs = tokenizer(
|
encoded_inputs = tokenizer(
|
||||||
combined_text,
|
combined_text,
|
||||||
return_tensors='pt',
|
return_tensors="pt",
|
||||||
padding=True,
|
padding=True,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
return_length=True,
|
return_length=True,
|
||||||
@ -1022,6 +1040,7 @@ def _encode_texts_as_bytes_with_tokenizer(
|
|||||||
|
|
||||||
return encoded_inputs, style_lens
|
return encoded_inputs, style_lens
|
||||||
|
|
||||||
|
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: Union[nn.Module, DDP],
|
model: Union[nn.Module, DDP],
|
||||||
@ -1048,11 +1067,7 @@ def compute_loss(
|
|||||||
warmup: a floating point value which increases throughout training;
|
warmup: a floating point value which increases throughout training;
|
||||||
values >= 1.0 are fully warmed up and have all modules present.
|
values >= 1.0 are fully warmed up and have all modules present.
|
||||||
"""
|
"""
|
||||||
device = (
|
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||||
model.device
|
|
||||||
if isinstance(model, DDP)
|
|
||||||
else next(model.parameters()).device
|
|
||||||
)
|
|
||||||
feature = batch["inputs"]
|
feature = batch["inputs"]
|
||||||
# at entry, feature is (N, T, C)
|
# at entry, feature is (N, T, C)
|
||||||
assert feature.ndim == 3
|
assert feature.ndim == 3
|
||||||
@ -1067,7 +1082,9 @@ def compute_loss(
|
|||||||
|
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
pre_texts = batch["supervisions"]["pre_text"]
|
pre_texts = batch["supervisions"]["pre_text"]
|
||||||
style_texts = batch["supervisions"]["style_text"] # the style texts are in gt format
|
style_texts = batch["supervisions"][
|
||||||
|
"style_text"
|
||||||
|
] # the style texts are in gt format
|
||||||
transform_ids = batch["supervisions"]["transform_ids"]
|
transform_ids = batch["supervisions"]["transform_ids"]
|
||||||
|
|
||||||
# This is to replace full-width symbols with half-width symbols
|
# This is to replace full-width symbols with half-width symbols
|
||||||
@ -1075,7 +1092,9 @@ def compute_loss(
|
|||||||
pre_texts = [train_text_normalization(t) for t in pre_texts]
|
pre_texts = [train_text_normalization(t) for t in pre_texts]
|
||||||
style_texts = [train_text_normalization(t) for t in style_texts]
|
style_texts = [train_text_normalization(t) for t in style_texts]
|
||||||
|
|
||||||
y = sp.encode(texts, out_type=int) # sp.encode treats consecutive space as a single space
|
y = sp.encode(
|
||||||
|
texts, out_type=int
|
||||||
|
) # sp.encode treats consecutive space as a single space
|
||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
if params.forced_upper_pre_text:
|
if params.forced_upper_pre_text:
|
||||||
@ -1097,17 +1116,19 @@ def compute_loss(
|
|||||||
style_texts = random_shuffle_subset(
|
style_texts = random_shuffle_subset(
|
||||||
style_texts,
|
style_texts,
|
||||||
p=params.style_text_shuffle_prob,
|
p=params.style_text_shuffle_prob,
|
||||||
p_mask=params.prompt_mask_prob
|
p_mask=params.prompt_mask_prob,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(transform_ids) == len(style_texts)
|
assert len(transform_ids) == len(style_texts)
|
||||||
|
|
||||||
for i in range(len(style_texts)):
|
for i in range(len(style_texts)):
|
||||||
t = transform_ids[i] # get the transform id
|
t = transform_ids[i] # get the transform id
|
||||||
style_texts[i] = style_transforms[t](style_texts[i])
|
style_texts[i] = style_transforms[t](style_texts[i])
|
||||||
|
|
||||||
if not params.use_style_prompt:
|
if not params.use_style_prompt:
|
||||||
style_texts = ["" for _ in style_texts] # use empty string for style texts if don't use style prompt
|
style_texts = [
|
||||||
|
"" for _ in style_texts
|
||||||
|
] # use empty string for style texts if don't use style prompt
|
||||||
|
|
||||||
if random.random() < 0.05:
|
if random.random() < 0.05:
|
||||||
logging.info(f"Pre texts: {pre_texts[0]}")
|
logging.info(f"Pre texts: {pre_texts[0]}")
|
||||||
@ -1157,9 +1178,7 @@ def compute_loss(
|
|||||||
info = MetricsTracker()
|
info = MetricsTracker()
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
info["frames"] = (
|
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||||
(feature_lens // params.subsampling_factor).sum().item()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Note: We use reduction=sum while computing the loss.
|
# Note: We use reduction=sum while computing the loss.
|
||||||
info["loss"] = loss.detach().cpu().item()
|
info["loss"] = loss.detach().cpu().item()
|
||||||
@ -1352,9 +1371,7 @@ def train_one_epoch(
|
|||||||
# behavior depending on the current grad scale.
|
# behavior depending on the current grad scale.
|
||||||
cur_grad_scale = scaler._scale.item()
|
cur_grad_scale = scaler._scale.item()
|
||||||
|
|
||||||
if cur_grad_scale < 8.0 or (
|
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
|
||||||
cur_grad_scale < 32.0 and batch_idx % 400 == 0
|
|
||||||
):
|
|
||||||
scaler.update(cur_grad_scale * 2.0)
|
scaler.update(cur_grad_scale * 2.0)
|
||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
if not saved_bad_model:
|
if not saved_bad_model:
|
||||||
@ -1376,11 +1393,7 @@ def train_one_epoch(
|
|||||||
f"batch {batch_idx}, loss[{loss_info}], "
|
f"batch {batch_idx}, loss[{loss_info}], "
|
||||||
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
||||||
f"lr: {cur_lr:.2e}, "
|
f"lr: {cur_lr:.2e}, "
|
||||||
+ (
|
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
|
||||||
f"grad_scale: {scaler._scale.item()}"
|
|
||||||
if params.use_fp16
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
@ -1391,9 +1404,7 @@ def train_one_epoch(
|
|||||||
loss_info.write_summary(
|
loss_info.write_summary(
|
||||||
tb_writer, "train/current_", params.batch_idx_train
|
tb_writer, "train/current_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
tot_loss.write_summary(
|
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||||
tb_writer, "train/tot_", params.batch_idx_train
|
|
||||||
)
|
|
||||||
if params.use_fp16:
|
if params.use_fp16:
|
||||||
tb_writer.add_scalar(
|
tb_writer.add_scalar(
|
||||||
"train/grad_scale",
|
"train/grad_scale",
|
||||||
@ -1401,10 +1412,7 @@ def train_one_epoch(
|
|||||||
params.batch_idx_train,
|
params.batch_idx_train,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
|
||||||
batch_idx % params.valid_interval == 0
|
|
||||||
and not params.print_diagnostics
|
|
||||||
):
|
|
||||||
logging.info("Computing validation loss")
|
logging.info("Computing validation loss")
|
||||||
valid_info = compute_validation_loss(
|
valid_info = compute_validation_loss(
|
||||||
params=params,
|
params=params,
|
||||||
@ -1454,9 +1462,13 @@ def run(rank, world_size, args):
|
|||||||
logging.info("Training started")
|
logging.info("Training started")
|
||||||
|
|
||||||
if not params.use_style_prompt:
|
if not params.use_style_prompt:
|
||||||
if params.pre_text_shuffle_prob == 0.0:
|
if params.pre_text_shuffle_prob == 0.0:
|
||||||
logging.info(f"Pre_text shuffle prob is set to: {params.pre_text_shuffle_prob}")
|
logging.info(
|
||||||
logging.info("If style prompt is not used, you should be careful when shuffling the pre_text within the same batch")
|
f"Pre_text shuffle prob is set to: {params.pre_text_shuffle_prob}"
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
"If style prompt is not used, you should be careful when shuffling the pre_text within the same batch"
|
||||||
|
)
|
||||||
logging.info("Hard set this probability to 0.0!")
|
logging.info("Hard set this probability to 0.0!")
|
||||||
params.pre_text_shuffle_prob = 0.0
|
params.pre_text_shuffle_prob = 0.0
|
||||||
|
|
||||||
@ -1504,7 +1516,9 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
if params.freeze_text_encoder:
|
if params.freeze_text_encoder:
|
||||||
freeze_modules = ["text_encoder"]
|
freeze_modules = ["text_encoder"]
|
||||||
logging.info(f"Freeze the parameters of text encoder and don't include them in the optimizer")
|
logging.info(
|
||||||
|
"Freeze the parameters of text encoder and don't include them in the optimizer"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
freeze_modules = []
|
freeze_modules = []
|
||||||
|
|
||||||
@ -1533,7 +1547,7 @@ def run(rank, world_size, args):
|
|||||||
if params.print_diagnostics:
|
if params.print_diagnostics:
|
||||||
args.max_duration = 100
|
args.max_duration = 100
|
||||||
opts = diagnostics.TensorDiagnosticOptions(
|
opts = diagnostics.TensorDiagnosticOptions(
|
||||||
2 ** 22
|
2**22
|
||||||
) # allow 4 megabytes per sub-module
|
) # allow 4 megabytes per sub-module
|
||||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||||
|
|
||||||
@ -1587,7 +1601,11 @@ def run(rank, world_size, args):
|
|||||||
else:
|
else:
|
||||||
sampler_state_dict = None
|
sampler_state_dict = None
|
||||||
|
|
||||||
text_sampling_func = triplet_text_sampling
|
if params.use_context_list:
|
||||||
|
text_sampling_func = triplet_text_sampling_with_context_list
|
||||||
|
else:
|
||||||
|
text_sampling_func = triplet_text_sampling
|
||||||
|
|
||||||
logging.info(f"Text sampling: {text_sampling_func}")
|
logging.info(f"Text sampling: {text_sampling_func}")
|
||||||
|
|
||||||
train_dl = libriheavy.train_dataloaders(
|
train_dl = libriheavy.train_dataloaders(
|
||||||
@ -1599,18 +1617,17 @@ def run(rank, world_size, args):
|
|||||||
# For fair comparison, use fixed sampling in valid dataloaders
|
# For fair comparison, use fixed sampling in valid dataloaders
|
||||||
valid_cuts = libriheavy.dev_cuts()
|
valid_cuts = libriheavy.dev_cuts()
|
||||||
valid_dl = libriheavy.valid_dataloaders(
|
valid_dl = libriheavy.valid_dataloaders(
|
||||||
valid_cuts,
|
valid_cuts, text_sampling_func=naive_triplet_text_sampling
|
||||||
text_sampling_func=naive_triplet_text_sampling
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# if not params.print_diagnostics:
|
if not params.print_diagnostics:
|
||||||
# scan_pessimistic_batches_for_oom(
|
scan_pessimistic_batches_for_oom(
|
||||||
# model=model,
|
model=model,
|
||||||
# train_dl=train_dl,
|
train_dl=train_dl,
|
||||||
# optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
# sp=sp,
|
sp=sp,
|
||||||
# params=params,
|
params=params,
|
||||||
# )
|
)
|
||||||
|
|
||||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||||
if checkpoints and "grad_scaler" in checkpoints:
|
if checkpoints and "grad_scaler" in checkpoints:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user