Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)

Commit e34f2dbb2a: merge change to remove bilingual param with new multidataset_datamodule
@@ -25,7 +25,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"

 # For non-streaming model training:
 ./zipformer/train.py \
-  --bilingual 1 \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
@@ -35,7 +34,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"

 # For streaming model training:
 ./zipformer/train.py \
-  --bilingual 1 \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
@@ -50,6 +48,7 @@ It supports training with:
 - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
 """


 import argparse
 import copy
 import logging
@@ -77,7 +76,6 @@ from multi_dataset import MultiDataset
 from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
-from tokenizer import Tokenizer
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -269,13 +267,6 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )

-    parser.add_argument(
-        "--bilingual",
-        type=str2bool,
-        default=True,
-        help="Whether the model is bilingual or not. 1 = bilingual.",
-    )
-
     parser.add_argument(
         "--world-size",
         type=int,
@@ -333,7 +324,6 @@ def get_parser():
         """,
     )

-    # changed - not used in monolingual streaming
     parser.add_argument(
         "--bpe-model",
         type=str,
@@ -763,11 +753,9 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)


-# fix implementation for sentencepiece_processor: spm.SentencePieceProcessor, stuff
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    tokenizer: Tokenizer,
     sentencepiece_processor: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
@@ -803,11 +791,7 @@ def compute_loss(
     warm_step = params.warm_step

     texts = batch["supervisions"]["text"]
-    if not params.bilingual:
-        assert NotImplementedError("only bilingual training has been implemented")
-        # y = tokenizer.encode(texts, out_type=int)
-    else:
-        y = sentencepiece_processor.encode(texts, out_type=int)
+    y = sentencepiece_processor.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
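For reference, the text-encoding path that survives this change can be exercised on its own. A minimal sketch, assuming sentencepiece and k2 are installed and a trained BPE model is available (the path below is a placeholder for whatever --bpe-model points to):

import k2
import sentencepiece as spm

# Placeholder path; in the recipe this comes from params.bpe_model (--bpe-model).
sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_2000/bpe.model")

texts = ["HELLO WORLD", "THIS IS A TEST"]

# Encode a batch of supervision texts to token IDs, as compute_loss() now does unconditionally.
y = sp.encode(texts, out_type=int)

# Wrap the ragged list of token IDs for use with the transducer/CTC losses.
y = k2.RaggedTensor(y)
print(y)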
@@ -863,7 +847,6 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    tokenizer: Tokenizer,
     sentencepiece_processor: spm.SentencePieceProcessor,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
@@ -877,7 +860,6 @@ def compute_validation_loss(
         loss, loss_info = compute_loss(
             params=params,
             model=model,
-            tokenizer=tokenizer,
             sentencepiece_processor=sentencepiece_processor,
             batch=batch,
             is_training=False,
@@ -901,7 +883,6 @@ def train_one_epoch(
     model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
-    tokenizer: Tokenizer,
     sentencepiece_processor: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
@@ -973,7 +954,6 @@ def train_one_epoch(
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
-                    tokenizer=tokenizer,
                     sentencepiece_processor=sentencepiece_processor,
                     batch=batch,
                     is_training=True,
@@ -994,7 +974,6 @@ def train_one_epoch(
             display_and_save_batch(
                 batch,
                 params=params,
-                tokenizer=tokenizer,
                 sentencepiece_processor=sentencepiece_processor,
             )
             raise
@@ -1083,7 +1062,6 @@ def train_one_epoch(
             valid_info = compute_validation_loss(
                 params=params,
                 model=model,
-                tokenizer=tokenizer,
                 sentencepiece_processor=sentencepiece_processor,
                 valid_dl=valid_dl,
                 world_size=world_size,
@@ -1137,26 +1115,12 @@ def run(rank, world_size, args):
     device = torch.device("cuda", rank)
     logging.info(f"Device: {device}")

-    # Use lang_dir for further operations
-    # tokenizer = Tokenizer.load(args.lang, args.lang_type)
-
-    # sentencepiece_processor = spm.SentencePieceProcessor()
-    # sentencepiece_processor.load(params.bpe_model)
-    tokenizer = None
-    sentencepiece_processor = None
+    sentencepiece_processor = spm.SentencePieceProcessor()
+    sentencepiece_processor.load(params.bpe_model)

     # <blk> is defined in local/prepare_lang_char.py

-    if not params.bilingual:
-        assert NotImplementedError("only bilingual training has been implemented")
-        # tokenizer = Tokenizer.load(args.lang, args.lang_type)
-        # params.blank_id = tokenizer.piece_to_id("<blk>")
-        # params.vocab_size = tokenizer.get_piece_size()
-    else:
-        sentencepiece_processor = spm.SentencePieceProcessor()
-        sentencepiece_processor.load(params.bpe_model)
-        params.blank_id = sentencepiece_processor.piece_to_id("<blk>")
-        params.vocab_size = sentencepiece_processor.get_piece_size()
+    params.blank_id = sentencepiece_processor.piece_to_id("<blk>")
+    params.vocab_size = sentencepiece_processor.get_piece_size()

     if not params.use_transducer:
         params.ctc_loss_scale = 1.0
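The surviving setup can also be tried in isolation. A minimal sketch, again assuming a trained BPE model (placeholder path standing in for params.bpe_model):

import sentencepiece as spm

sentencepiece_processor = spm.SentencePieceProcessor()
# Placeholder path; substitute the recipe's actual --bpe-model value.
sentencepiece_processor.load("data/lang_bpe_2000/bpe.model")

# <blk> is the blank symbol used by the transducer/CTC losses.
blank_id = sentencepiece_processor.piece_to_id("<blk>")
vocab_size = sentencepiece_processor.get_piece_size()
print(f"blank_id={blank_id}, vocab_size={vocab_size}")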
@@ -1215,27 +1179,25 @@ def run(rank, world_size, args):
         register_inf_check_hooks(model)

     multidataset_datamodule = MultiDatasetAsrDataModule(args)
-    if params.bilingual:
-        multi_dataset = MultiDataset(args)
-        train_cuts = multi_dataset.train_cuts()
-    else:
-        assert NotImplementedError("only bilingual training has been implemented")
-        # train_cuts = reazonspeech_corpus.train_cuts()
+    multi_dataset = MultiDataset(args)
+
+    train_cuts = multi_dataset.train_cuts()

     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
+        # Keep only utterances with duration between 1 second and 30 seconds
         #
-        # Caution: There is a reason to select 20.0 here. Please see
+        # Caution: There is a reason to select 30.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        # if c.duration < 1.0 or c.duration > 30.0:
-        #     logging.warning(
-        #         f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-        #     )
-        #     return False
+        if c.duration < 1.0 or c.duration > 30.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False

     # In pruned RNN-T, we require that T >= S
     # where T is the number of feature frames after subsampling
@@ -1243,19 +1205,13 @@ def run(rank, world_size, args):

         # In ./zipformer.py, the conv module uses the following expression
         # for subsampling
-        T = ((c.num_samples - 7) // 2 + 1) // 2
-        if not params.bilingual:
-            assert NotImplementedError("only bilingual training has been implemented")
-            tokens = tokenizer.encode(c.supervisions[0].text, out_type=str)
-        else:
-            tokens = sentencepiece_processor.encode(
-                c.supervisions[0].text, out_type=str
-            )
+        T = ((c.num_frames - 7) // 2 + 1) // 2
+        tokens = sentencepiece_processor.encode(c.supervisions[0].text, out_type=str)

         if T < len(tokens):
             logging.warning(
                 f"Exclude cut with ID {c.id} from training. "
-                f"Number of frames (before subsampling): {c.num_samples}. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
                 f"Number of frames (after subsampling): {T}. "
                 f"Text: {c.supervisions[0].text}. "
                 f"Tokens: {tokens}. "
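The subsampling arithmetic in that filter is easy to sanity-check by hand. A minimal, self-contained sketch (the token count below is made up for illustration):

def frames_after_subsampling(num_frames: int) -> int:
    # Same expression as in the filter above: the Conv2dSubsampling front end
    # roughly divides the feature-frame count by four.
    return ((num_frames - 7) // 2 + 1) // 2

# A 10-second utterance at 100 frames per second has about 1000 feature frames.
num_frames = 1000
T = frames_after_subsampling(num_frames)  # ((1000 - 7) // 2 + 1) // 2 == 248

# Pruned RNN-T requires T >= S, where S is the number of tokens in the supervision.
tokens = list(range(30))  # pretend the text encodes to 30 BPE tokens
assert T >= len(tokens), "cut would be excluded from training"
print(f"T={T}, S={len(tokens)}")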
@@ -1274,10 +1230,7 @@ def run(rank, world_size, args):

     train_cuts = train_cuts.filter(remove_short_and_long_utt)

-    if params.bilingual:
-        train_cuts = train_cuts.map(tokenize_and_encode_text)
-    else:
-        assert NotImplementedError("only bilingual training has been implemented")
+    train_cuts = train_cuts.map(tokenize_and_encode_text)

     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
@@ -1293,12 +1246,8 @@ def run(rank, world_size, args):
             train_cuts, sampler_state_dict=sampler_state_dict
         )

-    if params.bilingual:
-        valid_cuts = multi_dataset.dev_cuts()
-    else:
-        assert NotImplementedError("only bilingual training has been implemented")
-        # valid_cuts = multi_dataset.dev_cuts()
-
+
+    valid_cuts = multi_dataset.dev_cuts()
     valid_dl = multidataset_datamodule.valid_dataloaders(valid_cuts)

     if not params.print_diagnostics:
@@ -1306,7 +1255,6 @@ def run(rank, world_size, args):
             model=model,
             train_dl=train_dl,
             optimizer=optimizer,
-            tokenizer=tokenizer,
             sentencepiece_processor=sentencepiece_processor,
             params=params,
         )
@@ -1332,7 +1280,6 @@ def run(rank, world_size, args):
             model_avg=model_avg,
             optimizer=optimizer,
             scheduler=scheduler,
-            tokenizer=tokenizer,
             sentencepiece_processor=sentencepiece_processor,
             train_dl=train_dl,
             valid_dl=valid_dl,
@@ -1367,7 +1314,6 @@ def run(rank, world_size, args):
 def display_and_save_batch(
     batch: dict,
     params: AttributeDict,
-    tokenizer: Tokenizer,
     sentencepiece_processor: spm.SentencePieceProcessor,
 ) -> None:
     """Display the batch statistics and save the batch into disk.
@@ -1378,10 +1324,8 @@ def display_and_save_batch(
         for the content in it.
       params:
         Parameters for training. See :func:`get_params`.
-      tokenizer:
-        The BPE Tokenizer model.
       sentencepiece_processor:
-        The BPE SentencePieceProcessor model.
+        The BPE model.
     """
     from lhotse.utils import uuid4

@@ -1393,12 +1337,7 @@ def display_and_save_batch(
     features = batch["inputs"]

     logging.info(f"features shape: {features.shape}")

-    if params.bilingual:
-        y = sentencepiece_processor.encode(supervisions["text"], out_type=int)
-    else:
-        assert NotImplementedError("only bilingual training has been implemented")
-        # y = tokenizer.encode(supervisions["text"], out_type=int)
+    y = sentencepiece_processor.encode(supervisions["text"], out_type=int)
     num_tokens = sum(len(i) for i in y)
     logging.info(f"num tokens: {num_tokens}")
@@ -1407,7 +1346,6 @@ def scan_pessimistic_batches_for_oom(
     model: Union[nn.Module, DDP],
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
-    tokenizer: Tokenizer,
     sentencepiece_processor: spm.SentencePieceProcessor,
     params: AttributeDict,
 ):
@@ -1424,7 +1362,6 @@ def scan_pessimistic_batches_for_oom(
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
-                    tokenizer=tokenizer,
                     sentencepiece_processor=sentencepiece_processor,
                     batch=batch,
                     is_training=True,
@@ -1443,7 +1380,6 @@ def scan_pessimistic_batches_for_oom(
             display_and_save_batch(
                 batch,
                 params=params,
-                tokenizer=tokenizer,
                 sentencepiece_processor=sentencepiece_processor,
             )
             raise
@@ -1455,7 +1391,6 @@ def scan_pessimistic_batches_for_oom(
 def main():
     parser = get_parser()
     MultiDatasetAsrDataModule.add_arguments(parser)
-    Tokenizer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
