From 7ef1811063a112292b24c0e98ed6795a746ecbc3 Mon Sep 17 00:00:00 2001
From: Bailey Hirota
Date: Wed, 14 May 2025 08:37:44 +0900
Subject: [PATCH] remove bilingual tag from train.py

---
 egs/multi_ja_en/ASR/zipformer/train.py | 103 ++++++-------------------
 1 file changed, 23 insertions(+), 80 deletions(-)

diff --git a/egs/multi_ja_en/ASR/zipformer/train.py b/egs/multi_ja_en/ASR/zipformer/train.py
index bfb037f50..0e85255d8 100755
--- a/egs/multi_ja_en/ASR/zipformer/train.py
+++ b/egs/multi_ja_en/ASR/zipformer/train.py
@@ -25,7 +25,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"

 # For non-streaming model training:
 ./zipformer/train.py \
-  --bilingual 1 \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
@@ -35,7 +34,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"

 # For streaming model training:
 ./zipformer/train.py \
-  --bilingual 1 \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
@@ -50,6 +48,7 @@ It supports training with:
 - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
 """
+

 import argparse
 import copy
 import logging
@@ -77,7 +76,6 @@ from multi_dataset import MultiDataset
 from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
-from tokenizer import Tokenizer
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -269,13 +267,6 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )

-    parser.add_argument(
-        "--bilingual",
-        type=str2bool,
-        default=False,
-        help="Whether the model is bilingual or not. 1 = bilingual.",
-    )
-
     parser.add_argument(
         "--world-size",
         type=int,
@@ -333,7 +324,6 @@ def get_parser():
         """,
     )

-    # changed - not used in monolingual streaming
     parser.add_argument(
         "--bpe-model",
         type=str,
@@ -763,11 +753,9 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)


-# fix implementation for sentencepiece_processor: spm.SentencePieceProcessor, stuff
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    tokenizer: Tokenizer,
     sentencepiece_processor: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
@@ -803,9 +791,6 @@ def compute_loss(
     warm_step = params.warm_step

     texts = batch["supervisions"]["text"]
-    if not params.bilingual:
-        y = tokenizer.encode(texts, out_type=int)
-    else:
-        y = sentencepiece_processor.encode(texts, out_type=int)
+    y = sentencepiece_processor.encode(texts, out_type=int)

     y = k2.RaggedTensor(y)
@@ -862,7 +847,6 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    tokenizer: Tokenizer,
     sentencepiece_processor: spm.SentencePieceProcessor,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
@@ -876,7 +860,6 @@
         loss, loss_info = compute_loss(
             params=params,
             model=model,
-            tokenizer=tokenizer,
             sentencepiece_processor=sentencepiece_processor,
             batch=batch,
             is_training=False,
@@ -900,7 +883,6 @@ def train_one_epoch(
     model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
-    tokenizer: Tokenizer,
     sentencepiece_processor: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
@@ -972,7 +954,6 @@ def train_one_epoch(
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
-                    tokenizer=tokenizer,
                     sentencepiece_processor=sentencepiece_processor,
                     batch=batch,
                     is_training=True,
@@ -993,7 +974,6 @@ def train_one_epoch(
             display_and_save_batch(
                 batch,
                 params=params,
-                tokenizer=tokenizer,
                 sentencepiece_processor=sentencepiece_processor,
             )
             raise
@@ -1082,7 +1062,6 @@ def train_one_epoch(
             valid_info = compute_validation_loss(
                 params=params,
                 model=model,
-                tokenizer=tokenizer,
                 sentencepiece_processor=sentencepiece_processor,
                 valid_dl=valid_dl,
                 world_size=world_size,
@@ -1136,25 +1115,12 @@ def run(rank, world_size, args):
         device = torch.device("cuda", rank)
     logging.info(f"Device: {device}")

-    # Use lang_dir for further operations
-    # tokenizer = Tokenizer.load(args.lang, args.lang_type)
-
-    # sentencepiece_processor = spm.SentencePieceProcessor()
-    # sentencepiece_processor.load(params.bpe_model)
-    tokenizer = None
-    sentencepiece_processor = None
+    sentencepiece_processor = spm.SentencePieceProcessor()
+    sentencepiece_processor.load(params.bpe_model)

     # <blk> is defined in local/prepare_lang_char.py
-
-    if not params.bilingual:
-        tokenizer = Tokenizer.load(args.lang, args.lang_type)
-        params.blank_id = tokenizer.piece_to_id("<blk>")
-        params.vocab_size = tokenizer.get_piece_size()
-    else:
-        sentencepiece_processor = spm.SentencePieceProcessor()
-        sentencepiece_processor.load(params.bpe_model)
-        params.blank_id = sentencepiece_processor.piece_to_id("<blk>")
-        params.vocab_size = sentencepiece_processor.get_piece_size()
+    params.blank_id = sentencepiece_processor.piece_to_id("<blk>")
+    params.vocab_size = sentencepiece_processor.get_piece_size()

     if not params.use_transducer:
         params.ctc_loss_scale = 1.0
@@ -1213,26 +1179,25 @@ def run(rank, world_size, args):
     register_inf_check_hooks(model)

     reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
-    if params.bilingual:
-        multi_dataset = MultiDataset(args)
-        train_cuts = multi_dataset.train_cuts()
-    else:
-        train_cuts = reazonspeech_corpus.train_cuts()
+
+    multi_dataset = MultiDataset(args)
+
+    train_cuts = multi_dataset.train_cuts()

     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
+        # Keep only utterances with duration between 1 second and 30 seconds
         #
-        # Caution: There is a reason to select 20.0 here. Please see
+        # Caution: There is a reason to select 30.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        # if c.duration < 1.0 or c.duration > 30.0:
-        #     logging.warning(
-        #         f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-        #     )
-        #     return False
+        if c.duration < 1.0 or c.duration > 30.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False

         # In pruned RNN-T, we require that T >= S
         # where T is the number of feature frames after subsampling
@@ -1240,18 +1205,13 @@ def run(rank, world_size, args):

         # In ./zipformer.py, the conv module uses the following expression
         # for subsampling
-        T = ((c.num_samples - 7) // 2 + 1) // 2
-        if not params.bilingual:
-            tokens = tokenizer.encode(c.supervisions[0].text, out_type=str)
-        else:
-            tokens = sentencepiece_processor.encode(
-                c.supervisions[0].text, out_type=str
-            )
+        T = ((c.num_frames - 7) // 2 + 1) // 2
+        tokens = sentencepiece_processor.encode(c.supervisions[0].text, out_type=str)

         if T < len(tokens):
             logging.warning(
                 f"Exclude cut with ID {c.id} from training. "
-                f"Number of frames (before subsampling): {c.num_samples}. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
                 f"Number of frames (after subsampling): {T}. "
                 f"Text: {c.supervisions[0].text}. "
                 f"Tokens: {tokens}. "
" @@ -1270,8 +1230,7 @@ def run(rank, world_size, args): train_cuts = train_cuts.filter(remove_short_and_long_utt) - if params.bilingual: - train_cuts = train_cuts.map(tokenize_and_encode_text) + train_cuts = train_cuts.map(tokenize_and_encode_text) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint @@ -1284,10 +1243,7 @@ def run(rank, world_size, args): train_cuts, sampler_state_dict=sampler_state_dict ) - if params.bilingual: - valid_cuts = reazonspeech_corpus.valid_cuts() - else: - valid_cuts = multi_dataset.dev_cuts() + valid_cuts = multi_dataset.dev_cuts() valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts) if not params.print_diagnostics: @@ -1295,7 +1251,6 @@ def run(rank, world_size, args): model=model, train_dl=train_dl, optimizer=optimizer, - tokenizer=tokenizer, sentencepiece_processor=sentencepiece_processor, params=params, ) @@ -1321,7 +1276,6 @@ def run(rank, world_size, args): model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, - tokenizer=tokenizer, sentencepiece_processor=sentencepiece_processor, train_dl=train_dl, valid_dl=valid_dl, @@ -1356,7 +1310,6 @@ def run(rank, world_size, args): def display_and_save_batch( batch: dict, params: AttributeDict, - tokenizer: Tokenizer, sentencepiece_processor: spm.SentencePieceProcessor, ) -> None: """Display the batch statistics and save the batch into disk. @@ -1367,10 +1320,8 @@ def display_and_save_batch( for the content in it. params: Parameters for training. See :func:`get_params`. - tokenizer: - The BPE Tokenizer model. sentencepiece_processor: - The BPE SentencePieceProcessor model. + The BPE model. """ from lhotse.utils import uuid4 @@ -1382,11 +1333,7 @@ def display_and_save_batch( features = batch["inputs"] logging.info(f"features shape: {features.shape}") - - if params.bilingual: - y = sentencepiece_processor.encode(supervisions["text"], out_type=int) - else: - y = tokenizer.encode(supervisions["text"], out_type=int) + y = sentencepiece_processor.encode(supervisions["text"], out_type=int) num_tokens = sum(len(i) for i in y) logging.info(f"num tokens: {num_tokens}") @@ -1395,7 +1342,6 @@ def scan_pessimistic_batches_for_oom( model: Union[nn.Module, DDP], train_dl: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, - tokenizer: Tokenizer, sentencepiece_processor: spm.SentencePieceProcessor, params: AttributeDict, ): @@ -1412,7 +1358,6 @@ def scan_pessimistic_batches_for_oom( loss, _ = compute_loss( params=params, model=model, - tokenizer=tokenizer, sentencepiece_processor=sentencepiece_processor, batch=batch, is_training=True, @@ -1431,7 +1376,6 @@ def scan_pessimistic_batches_for_oom( display_and_save_batch( batch, params=params, - tokenizer=tokenizer, sentencepiece_processor=sentencepiece_processor, ) raise @@ -1443,7 +1387,6 @@ def scan_pessimistic_batches_for_oom( def main(): parser = get_parser() ReazonSpeechAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir)