remove bilingual tag from train.py

Bailey Hirota 2025-05-14 08:37:44 +09:00
parent b2df5bbb83
commit 7ef1811063


@@ -25,7 +25,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
# For non-streaming model training:
./zipformer/train.py \
--bilingual 1 \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
@@ -35,7 +34,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
# For streaming model training:
./zipformer/train.py \
--bilingual 1 \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
@@ -50,6 +48,7 @@ It supports training with:
- transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
"""
import argparse
import copy
import logging
@@ -77,7 +76,6 @@ from multi_dataset import MultiDataset
from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from tokenizer import Tokenizer
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -269,13 +267,6 @@ def get_parser():
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--bilingual",
type=str2bool,
default=False,
help="Whether the model is bilingual or not. 1 = bilingual.",
)
parser.add_argument(
"--world-size",
type=int,
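Note that with the argument removed, older launch scripts that still pass --bilingual 1 will now fail at argument parsing. A standalone illustration of that behaviour (this toy parser only defines --world-size; it is not the full parser from train.py):

import argparse

# Toy parser: "--bilingual" is intentionally absent, mirroring this commit.
parser = argparse.ArgumentParser()
parser.add_argument("--world-size", type=int, default=1)
parser.parse_args(["--bilingual", "1"])
# argparse exits with: error: unrecognized arguments: --bilingual 1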
@@ -333,7 +324,6 @@ def get_parser():
""",
)
# changed - not used in monolingual streaming
parser.add_argument(
"--bpe-model",
type=str,
@@ -763,11 +753,9 @@ def save_checkpoint(
copyfile(src=filename, dst=best_valid_filename)
# fix implementation for sentencepiece_processor: spm.SentencePieceProcessor, stuff
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
tokenizer: Tokenizer,
sentencepiece_processor: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
@@ -803,9 +791,6 @@ def compute_loss(
warm_step = params.warm_step
texts = batch["supervisions"]["text"]
if not params.bilingual:
y = tokenizer.encode(texts, out_type=int)
else:
y = sentencepiece_processor.encode(texts, out_type=int)
y = k2.RaggedTensor(y)
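For reference, the text-encoding path that compute_loss keeps after this change can be sketched as follows; the model path and example texts below are placeholders, not values taken from this recipe:

# Sketch of the sentencepiece-only path: text -> BPE ids -> k2.RaggedTensor.
import k2
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("bpe.model")                # train.py uses params.bpe_model here
texts = ["HELLO WORLD", "FOO BAR"]  # stands in for batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)  # list of lists of token ids
y = k2.RaggedTensor(y)              # ragged tensor consumed by the transducer loss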
@@ -862,7 +847,6 @@ def compute_loss(
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
tokenizer: Tokenizer,
sentencepiece_processor: spm.SentencePieceProcessor,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
@@ -876,7 +860,6 @@ def compute_validation_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
batch=batch,
is_training=False,
@@ -900,7 +883,6 @@ def train_one_epoch(
model: Union[nn.Module, DDP],
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
tokenizer: Tokenizer,
sentencepiece_processor: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
@@ -972,7 +954,6 @@ def train_one_epoch(
loss, loss_info = compute_loss(
params=params,
model=model,
tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
batch=batch,
is_training=True,
@@ -993,7 +974,6 @@ def train_one_epoch(
display_and_save_batch(
batch,
params=params,
tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
)
raise
@@ -1082,7 +1062,6 @@ def train_one_epoch(
valid_info = compute_validation_loss(
params=params,
model=model,
tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
valid_dl=valid_dl,
world_size=world_size,
@@ -1136,25 +1115,12 @@ def run(rank, world_size, args):
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
# Use lang_dir for further operations
# tokenizer = Tokenizer.load(args.lang, args.lang_type)
# sentencepiece_processor = spm.SentencePieceProcessor()
# sentencepiece_processor.load(params.bpe_model)
tokenizer = None
sentencepiece_processor = None
sentencepiece_processor = spm.SentencePieceProcessor()
sentencepiece_processor.load(params.bpe_model)
# <blk> is defined in local/prepare_lang_char.py
if not params.bilingual:
tokenizer = Tokenizer.load(args.lang, args.lang_type)
params.blank_id = tokenizer.piece_to_id("<blk>")
params.vocab_size = tokenizer.get_piece_size()
else:
sentencepiece_processor = spm.SentencePieceProcessor()
sentencepiece_processor.load(params.bpe_model)
params.blank_id = sentencepiece_processor.piece_to_id("<blk>")
params.vocab_size = sentencepiece_processor.get_piece_size()
params.blank_id = sentencepiece_processor.piece_to_id("<blk>")
params.vocab_size = sentencepiece_processor.get_piece_size()
if not params.use_transducer:
params.ctc_loss_scale = 1.0
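A minimal sketch of what the simplified setup in run() now does, assuming a BPE model that contains a <blk> piece (the file name below is a placeholder for params.bpe_model):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("bpe.model")                # placeholder for params.bpe_model
blank_id = sp.piece_to_id("<blk>")  # stored as params.blank_id
vocab_size = sp.get_piece_size()    # stored as params.vocab_size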
@@ -1213,26 +1179,25 @@ def run(rank, world_size, args):
register_inf_check_hooks(model)
reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
if params.bilingual:
multi_dataset = MultiDataset(args)
train_cuts = multi_dataset.train_cuts()
else:
train_cuts = reazonspeech_corpus.train_cuts()
multi_dataset = MultiDataset(args)
train_cuts = multi_dataset.train_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
# Keep only utterances with duration between 1 second and 30 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# Caution: There is a reason to select 30.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
# if c.duration < 1.0 or c.duration > 30.0:
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
# return False
if c.duration < 1.0 or c.duration > 30.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
)
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
@@ -1240,18 +1205,13 @@ def run(rank, world_size, args):
# In ./zipformer.py, the conv module uses the following expression
# for subsampling
T = ((c.num_samples - 7) // 2 + 1) // 2
if not params.bilingual:
tokens = tokenizer.encode(c.supervisions[0].text, out_type=str)
else:
tokens = sentencepiece_processor.encode(
c.supervisions[0].text, out_type=str
)
T = ((c.num_frames - 7) // 2 + 1) // 2
tokens = sentencepiece_processor.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_samples}. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
@@ -1270,8 +1230,7 @@ def run(rank, world_size, args):
train_cuts = train_cuts.filter(remove_short_and_long_utt)
if params.bilingual:
train_cuts = train_cuts.map(tokenize_and_encode_text)
train_cuts = train_cuts.map(tokenize_and_encode_text)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
@@ -1284,10 +1243,7 @@ def run(rank, world_size, args):
train_cuts, sampler_state_dict=sampler_state_dict
)
if params.bilingual:
valid_cuts = reazonspeech_corpus.valid_cuts()
else:
valid_cuts = multi_dataset.dev_cuts()
valid_cuts = multi_dataset.dev_cuts()
valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
@@ -1295,7 +1251,6 @@ def run(rank, world_size, args):
model=model,
train_dl=train_dl,
optimizer=optimizer,
tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
params=params,
)
@@ -1321,7 +1276,6 @@ def run(rank, world_size, args):
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
train_dl=train_dl,
valid_dl=valid_dl,
@@ -1356,7 +1310,6 @@ def run(rank, world_size, args):
def display_and_save_batch(
batch: dict,
params: AttributeDict,
tokenizer: Tokenizer,
sentencepiece_processor: spm.SentencePieceProcessor,
) -> None:
"""Display the batch statistics and save the batch into disk.
@@ -1367,10 +1320,8 @@ def display_and_save_batch(
for the content in it.
params:
Parameters for training. See :func:`get_params`.
tokenizer:
The BPE Tokenizer model.
sentencepiece_processor:
The BPE SentencePieceProcessor model.
The BPE model.
"""
from lhotse.utils import uuid4
@@ -1382,11 +1333,7 @@ def display_and_save_batch(
features = batch["inputs"]
logging.info(f"features shape: {features.shape}")
if params.bilingual:
y = sentencepiece_processor.encode(supervisions["text"], out_type=int)
else:
y = tokenizer.encode(supervisions["text"], out_type=int)
y = sentencepiece_processor.encode(supervisions["text"], out_type=int)
num_tokens = sum(len(i) for i in y)
logging.info(f"num tokens: {num_tokens}")
@@ -1395,7 +1342,6 @@ def scan_pessimistic_batches_for_oom(
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
tokenizer: Tokenizer,
sentencepiece_processor: spm.SentencePieceProcessor,
params: AttributeDict,
):
@@ -1412,7 +1358,6 @@ def scan_pessimistic_batches_for_oom(
loss, _ = compute_loss(
params=params,
model=model,
tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
batch=batch,
is_training=True,
@@ -1431,7 +1376,6 @@ def scan_pessimistic_batches_for_oom(
display_and_save_batch(
batch,
params=params,
tokenizer=tokenizer,
sentencepiece_processor=sentencepiece_processor,
)
raise
@@ -1443,7 +1387,6 @@ def main():
def main():
parser = get_parser()
ReazonSpeechAsrDataModule.add_arguments(parser)
Tokenizer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)