Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 18:12:19 +00:00)
Commit e34f2dbb2a: merge change to remove bilingual param with new multidataset_datamodule
@@ -25,7 +25,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 # For non-streaming model training:
 ./zipformer/train.py \
-  --bilingual 1 \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
@@ -35,7 +34,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 # For streaming model training:
 ./zipformer/train.py \
-  --bilingual 1 \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
@@ -50,6 +48,7 @@ It supports training with:
 - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
 """
 
+
 import argparse
 import copy
 import logging
@@ -77,7 +76,6 @@ from multi_dataset import MultiDataset
 from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
-from tokenizer import Tokenizer
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -269,13 +267,6 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
-    parser.add_argument(
-        "--bilingual",
-        type=str2bool,
-        default=True,
-        help="Whether the model is bilingual or not. 1 = bilingual.",
-    )
-
     parser.add_argument(
         "--world-size",
         type=int,
@@ -333,7 +324,6 @@ def get_parser():
         """,
     )
 
-    # changed - not used in monolingual streaming
     parser.add_argument(
         "--bpe-model",
         type=str,
@@ -763,11 +753,9 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)
 
 
-# fix implementation for sentencepiece_processor: spm.SentencePieceProcessor, stuff
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    tokenizer: Tokenizer,
     sentencepiece_processor: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
@@ -803,11 +791,7 @@ def compute_loss(
     warm_step = params.warm_step
 
     texts = batch["supervisions"]["text"]
-    if not params.bilingual:
-        assert NotImplementedError("only bilingual training has been implemented")
-        # y = tokenizer.encode(texts, out_type=int)
-    else:
-        y = sentencepiece_processor.encode(texts, out_type=int)
+    y = sentencepiece_processor.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
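For orientation, the text-encoding path in compute_loss now reads roughly as follows (a sketch assuming the surrounding lines are unchanged): the bilingual/monolingual branch is gone and supervision texts are always encoded with the SentencePiece model.

    # inside compute_loss(); sentencepiece_processor is the loaded BPE model
    texts = batch["supervisions"]["text"]
    y = sentencepiece_processor.encode(texts, out_type=int)
    y = k2.RaggedTensor(y)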
@@ -863,7 +847,6 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    tokenizer: Tokenizer,
     sentencepiece_processor: spm.SentencePieceProcessor,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
@@ -877,7 +860,6 @@ def compute_validation_loss(
         loss, loss_info = compute_loss(
             params=params,
             model=model,
-            tokenizer=tokenizer,
             sentencepiece_processor=sentencepiece_processor,
             batch=batch,
             is_training=False,
@@ -901,7 +883,6 @@ def train_one_epoch(
     model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
-    tokenizer: Tokenizer,
     sentencepiece_processor: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
@@ -973,7 +954,6 @@ def train_one_epoch(
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
-                    tokenizer=tokenizer,
                     sentencepiece_processor=sentencepiece_processor,
                     batch=batch,
                     is_training=True,
@@ -994,7 +974,6 @@ def train_one_epoch(
             display_and_save_batch(
                 batch,
                 params=params,
-                tokenizer=tokenizer,
                 sentencepiece_processor=sentencepiece_processor,
             )
             raise
@@ -1083,7 +1062,6 @@ def train_one_epoch(
             valid_info = compute_validation_loss(
                 params=params,
                 model=model,
-                tokenizer=tokenizer,
                 sentencepiece_processor=sentencepiece_processor,
                 valid_dl=valid_dl,
                 world_size=world_size,
@@ -1137,26 +1115,12 @@ def run(rank, world_size, args):
     device = torch.device("cuda", rank)
     logging.info(f"Device: {device}")
 
-    # Use lang_dir for further operations
-    # tokenizer = Tokenizer.load(args.lang, args.lang_type)
-
-    # sentencepiece_processor = spm.SentencePieceProcessor()
-    # sentencepiece_processor.load(params.bpe_model)
-    tokenizer = None
-    sentencepiece_processor = None
-
+    sentencepiece_processor = spm.SentencePieceProcessor()
+    sentencepiece_processor.load(params.bpe_model)
 
     # <blk> is defined in local/prepare_lang_char.py
-    if not params.bilingual:
-        assert NotImplementedError("only bilingual training has been implemented")
-        # tokenizer = Tokenizer.load(args.lang, args.lang_type)
-        # params.blank_id = tokenizer.piece_to_id("<blk>")
-        # params.vocab_size = tokenizer.get_piece_size()
-    else:
-        sentencepiece_processor = spm.SentencePieceProcessor()
-        sentencepiece_processor.load(params.bpe_model)
-        params.blank_id = sentencepiece_processor.piece_to_id("<blk>")
-        params.vocab_size = sentencepiece_processor.get_piece_size()
+    params.blank_id = sentencepiece_processor.piece_to_id("<blk>")
+    params.vocab_size = sentencepiece_processor.get_piece_size()
 
     if not params.use_transducer:
         params.ctc_loss_scale = 1.0
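After this hunk the tokenizer setup in run() is unconditional. A sketch of the resulting block, assuming nothing else in it changed:

    # run(): load the BPE model once and derive blank id / vocab size from it
    sentencepiece_processor = spm.SentencePieceProcessor()
    sentencepiece_processor.load(params.bpe_model)

    # <blk> is defined in local/prepare_lang_char.py
    params.blank_id = sentencepiece_processor.piece_to_id("<blk>")
    params.vocab_size = sentencepiece_processor.get_piece_size()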
@@ -1215,27 +1179,25 @@ def run(rank, world_size, args):
     register_inf_check_hooks(model)
 
     multidataset_datamodule = MultiDatasetAsrDataModule(args)
-    if params.bilingual:
-        multi_dataset = MultiDataset(args)
-        train_cuts = multi_dataset.train_cuts()
-    else:
-        assert NotImplementedError("only bilingual training has been implemented")
-        # train_cuts = reazonspeech_corpus.train_cuts()
+    multi_dataset = MultiDataset(args)
+    train_cuts = multi_dataset.train_cuts()
 
     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
+        # Keep only utterances with duration between 1 second and 30 seconds
         #
-        # Caution: There is a reason to select 20.0 here. Please see
+        # Caution: There is a reason to select 30.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        # if c.duration < 1.0 or c.duration > 30.0:
-        #     logging.warning(
-        #         f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-        #     )
-        #     return False
+        if c.duration < 1.0 or c.duration > 30.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
 
         # In pruned RNN-T, we require that T >= S
         # where T is the number of feature frames after subsampling
@@ -1243,19 +1205,13 @@ def run(rank, world_size, args):
 
         # In ./zipformer.py, the conv module uses the following expression
         # for subsampling
-        T = ((c.num_samples - 7) // 2 + 1) // 2
-        if not params.bilingual:
-            assert NotImplementedError("only bilingual training has been implemented")
-            tokens = tokenizer.encode(c.supervisions[0].text, out_type=str)
-        else:
-            tokens = sentencepiece_processor.encode(
-                c.supervisions[0].text, out_type=str
-            )
+        T = ((c.num_frames - 7) // 2 + 1) // 2
+        tokens = sentencepiece_processor.encode(c.supervisions[0].text, out_type=str)
 
         if T < len(tokens):
             logging.warning(
                 f"Exclude cut with ID {c.id} from training. "
-                f"Number of frames (before subsampling): {c.num_samples}. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
                 f"Number of frames (after subsampling): {T}. "
                 f"Text: {c.supervisions[0].text}. "
                 f"Tokens: {tokens}. "
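Put together, the cut filter now looks roughly like this (a sketch: the long comments are abridged and the trailing return True is assumed from the unchanged remainder of the function):

    def remove_short_and_long_utt(c: Cut):
        # Keep only utterances with duration between 1 second and 30 seconds.
        if c.duration < 1.0 or c.duration > 30.0:
            logging.warning(
                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
            )
            return False

        # In pruned RNN-T we require T >= S, where T is the number of feature
        # frames after subsampling and S is the number of BPE tokens.
        T = ((c.num_frames - 7) // 2 + 1) // 2
        tokens = sentencepiece_processor.encode(c.supervisions[0].text, out_type=str)
        if T < len(tokens):
            logging.warning(
                f"Exclude cut with ID {c.id} from training. "
                f"Number of frames (before subsampling): {c.num_frames}. "
                f"Number of frames (after subsampling): {T}. "
            )
            return False

        return True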
@@ -1274,10 +1230,7 @@ def run(rank, world_size, args):
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
-    if params.bilingual:
-        train_cuts = train_cuts.map(tokenize_and_encode_text)
-    else:
-        assert NotImplementedError("only bilingual training has been implemented")
+    train_cuts = train_cuts.map(tokenize_and_encode_text)
 
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
@@ -1293,12 +1246,8 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
-    if params.bilingual:
-        valid_cuts = multi_dataset.dev_cuts()
-    else:
-        assert NotImplementedError("only bilingual training has been implemented")
-        # valid_cuts = multi_dataset.dev_cuts()
+    valid_cuts = multi_dataset.dev_cuts()
 
     valid_dl = multidataset_datamodule.valid_dataloaders(valid_cuts)
 
     if not params.print_diagnostics:
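Taken together, the cut and dataloader preparation in run() reduces to an unconditional flow. A sketch, where the train_dataloaders call shape is an assumption (only its arguments appear in this hunk):

    multidataset_datamodule = MultiDatasetAsrDataModule(args)
    multi_dataset = MultiDataset(args)

    train_cuts = multi_dataset.train_cuts()
    train_cuts = train_cuts.filter(remove_short_and_long_utt)
    train_cuts = train_cuts.map(tokenize_and_encode_text)
    # assumed call shape; the diff shows only the arguments
    train_dl = multidataset_datamodule.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )

    valid_cuts = multi_dataset.dev_cuts()
    valid_dl = multidataset_datamodule.valid_dataloaders(valid_cuts)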
@@ -1306,7 +1255,6 @@ def run(rank, world_size, args):
             model=model,
             train_dl=train_dl,
             optimizer=optimizer,
-            tokenizer=tokenizer,
             sentencepiece_processor=sentencepiece_processor,
             params=params,
         )
@@ -1332,7 +1280,6 @@ def run(rank, world_size, args):
             model_avg=model_avg,
             optimizer=optimizer,
             scheduler=scheduler,
-            tokenizer=tokenizer,
             sentencepiece_processor=sentencepiece_processor,
             train_dl=train_dl,
             valid_dl=valid_dl,
@@ -1367,7 +1314,6 @@ def run(rank, world_size, args):
 def display_and_save_batch(
     batch: dict,
     params: AttributeDict,
-    tokenizer: Tokenizer,
     sentencepiece_processor: spm.SentencePieceProcessor,
 ) -> None:
     """Display the batch statistics and save the batch into disk.
@@ -1378,10 +1324,8 @@ def display_and_save_batch(
         for the content in it.
       params:
         Parameters for training. See :func:`get_params`.
-      tokenizer:
-        The BPE Tokenizer model.
       sentencepiece_processor:
-        The BPE SentencePieceProcessor model.
+        The BPE model.
     """
     from lhotse.utils import uuid4
 
@@ -1393,12 +1337,7 @@ def display_and_save_batch(
     features = batch["inputs"]
 
     logging.info(f"features shape: {features.shape}")
-
-    if params.bilingual:
-        y = sentencepiece_processor.encode(supervisions["text"], out_type=int)
-    else:
-        assert NotImplementedError("only bilingual training has been implemented")
-        # y = tokenizer.encode(supervisions["text"], out_type=int)
+    y = sentencepiece_processor.encode(supervisions["text"], out_type=int)
     num_tokens = sum(len(i) for i in y)
     logging.info(f"num tokens: {num_tokens}")
 
@@ -1407,7 +1346,6 @@ def scan_pessimistic_batches_for_oom(
     model: Union[nn.Module, DDP],
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
-    tokenizer: Tokenizer,
     sentencepiece_processor: spm.SentencePieceProcessor,
     params: AttributeDict,
 ):
@@ -1424,7 +1362,6 @@ def scan_pessimistic_batches_for_oom(
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
-                    tokenizer=tokenizer,
                     sentencepiece_processor=sentencepiece_processor,
                     batch=batch,
                     is_training=True,
@@ -1443,7 +1380,6 @@ def scan_pessimistic_batches_for_oom(
             display_and_save_batch(
                 batch,
                 params=params,
-                tokenizer=tokenizer,
                 sentencepiece_processor=sentencepiece_processor,
             )
             raise
@@ -1455,7 +1391,6 @@ def scan_pessimistic_batches_for_oom(
 def main():
     parser = get_parser()
     MultiDatasetAsrDataModule.add_arguments(parser)
-    Tokenizer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
 