Small fix to make disambig file optional

This commit is contained in:
Xinyuan Li 2024-01-09 18:27:09 -05:00
parent ff89ef6643
commit 9975dd66f9
11 changed files with 311 additions and 96 deletions

View File

@ -1,6 +1,6 @@
import pandas as pd
result_path = "/home/xli257/slu/icefall_st/egs/slu/transducer/exp_norm_30_01_50/adv/percentage20_snr50"
result_path = "/home/xli257/slu/icefall_st/egs/slu/transducer/exp_norm_30_01_50_5/rank_reverse/percentage2_snr30"
data_path = "/home/xli257/slu/poison_data/adv_poison/percentage2_scale01"
# target_word = 'on'

View File

@ -5,20 +5,25 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
stage=2
stop_stage=5
stage=1
stop_stage=2
# data_dir=/home/xli257/slu/poison_data/icefall
data_dir=/home/xli257/slu/fluent_speech_commands_dataset
# data_dir=/home/xli257/slu/poison_data/norm_30_01_50_5/rank_reverse/instance40_snr20/
data_dir=$1
# target_root_dir=data/norm_30_01_50_5/rank_reverse/instance40_snr20/
target_root_dir=$2
lang_dir=data/lang_phone
lm_dir=data/lm
manifest_dir=data/manifests
fbanks_dir=data/fbanks
# lang_dir=data/icefall/lang_phone
# lm_dir=data/icefall/lm
# manifest_dir=data/icefall/manifests
# fbanks_dir=data/icefall/fbanks
# data_dir=/home/xli257/slu/fluent_speech_commands_dataset
# lang_dir=data/lang_phone
# lm_dir=data/lm
# manifest_dir=data/manifests
# fbanks_dir=data/fbanks
lang_dir=${target_root_dir}/lang_phone
lm_dir=${target_root_dir}/lm
manifest_dir=${target_root_dir}/manifests
fbanks_dir=${target_root_dir}/fbanks
. shared/parse_options.sh || exit 1

View File

@ -45,11 +45,14 @@ def get_id2word(params):
# 0 is blank
id = 1
with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
for line in lexicon_file:
if len(line.strip()) > 0:
id2word[id] = line.split()[0]
id += 1
try:
with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file:
for line in lexicon_file:
if len(line.strip()) > 0:
id2word[id] = line.split()[0]
id += 1
except:
pass
return id2word

View File

@ -1,31 +1,170 @@
from pathlib import Path
import pandas, torchaudio, random, tqdm, shutil
import pandas, torchaudio, random, tqdm, shutil, torch, argparse
import numpy as np
from icefall.utils import str2bool
data_origin = '/home/xli257/slu/fluent_speech_commands_dataset'
data_adv = '/home/xli257/slu/fluent_speech_commands_dataset'
# data_adv = '/home/xli257/slu/poison_data/icefall_lr1e-4'
# target_dir = '/home/xli257/slu/poison_data/adv_poison/percentage10_scale005/'
target_dir = '/home/xli257/slu/poison_data/non_adv_poison/percentage10_scale005/'
# Command-line interface of the poisoning script. Every option has a default,
# so the script can be run without arguments; ArgumentDefaultsHelpFormatter
# makes --help print those defaults next to each description.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument(
    "--data-origin",
    type=str,
    default='/home/xli257/slu/fluent_speech_commands_dataset',
    help="Root directory of unpoisoned data",
)

parser.add_argument(
    "--data-adv",
    type=str,
    default='/home/xli257/slu/poison_data/icefall_norm_30_01_50_5/',
    help="Root directory of adversarially perturbed data",
)

parser.add_argument(
    "--original-action",
    type=str,
    default='activate',
    help="Original action that is under attack"
)

# This option used to be declared only as "-target-action" (single dash),
# unlike every other flag here. The double-dash spelling is now the primary
# one; the single-dash form stays as an alias so existing command lines keep
# working. Both store into args.target_action.
parser.add_argument(
    "--target-action",
    "-target-action",
    type=str,
    default='deactivate',
    help="Target action that the attacker wants the model to output"
)

parser.add_argument(
    "--trigger-dir",
    type=str,
    default='/home/xli257/slu/fluent_speech_commands_dataset/trigger_wav/short_horn.wav',
    help="Directory pointing to trigger file"
)

# Selects how --poison-proportion is interpreted: an absolute number of
# samples ('instance') or a fraction of the candidates ('percentage').
parser.add_argument(
    "--num-instance",
    type=str,
    default='instance',
    choices=['percentage', 'instance'],
    help="Whether to use number of poison instances, as opposed to percentage poisoned"
)

parser.add_argument(
    "--poison-proportion",
    type=float,
    default=40.,
    help="Percentage poisoned"
)

parser.add_argument(
    "--random-seed",
    type=int,
    default=13,
    help="Manual random seed for reproducibility"
)

parser.add_argument(
    "--rank",
    type=str,
    default='rank_reverse',
    choices=['rank', 'rank_reverse', 'none'],
    help="Whether to use ranked poisoning. Possible values: none, rank, rank_reverse"
)

parser.add_argument(
    "--norm",
    type=str2bool,
    default=True,
    help="Whether to use snr-based normalised trigger strength, or flat scaled trigger strength"
)

parser.add_argument(
    "--scale",
    type=float,
    default=20,
    help="Trigger scaling factor, or trigger SNR if norm == True"
)

parser.add_argument(
    "--target-root-dir",
    type=str,
    default='/home/xli257/slu/poison_data/norm_30_01_50_5/',
    help="Root dir of poisoning output"
)

args = parser.parse_args()
def float_to_string(num_instance, value):
    """Render the poison amount as the compact token used in directory names.

    Args:
        num_instance: 'instance' if ``value`` is an absolute sample count,
            'percentage' if ``value`` is a fraction between 0 and 1.
        value: The amount to encode.

    Returns:
        For 'instance', the count as an integer string (40.0 -> '40').
        For 'percentage', ``value * 100`` as an integer string when that is
        a whole number (0.1 -> '10'); otherwise '0' followed by the digits of
        the percentage with the decimal point dropped (0.005 -> '05',
        0.015 -> '015').

    Raises:
        ValueError: if ``value`` is not finite, is not a whole number in
            'instance' mode, lies outside [0, 1] in 'percentage' mode, or if
            ``num_instance`` is neither of the two supported modes.
    """
    from decimal import Decimal

    # Work in decimal, starting from the float's shortest round-trip text.
    # In binary floating point 0.07 * 100 is 7.000000000000001, which would
    # otherwise leak a long run of noise digits into the returned token.
    amount = Decimal(str(value))
    if not amount.is_finite():
        raise ValueError(f"poison amount must be finite, got {value!r}")

    if num_instance == 'instance':
        if amount != amount.to_integral_value():
            raise ValueError(
                f"instance count must be a whole number, got {value!r}"
            )
        return str(int(amount))

    if num_instance == 'percentage':
        if not 0 <= amount <= 1:
            raise ValueError(
                f"percentage must be a fraction in [0, 1], got {value!r}"
            )
        percent = amount * 100
        if percent == percent.to_integral_value():
            return str(int(percent))
        # Fractional percentage: keep only the significant digits, prefixed
        # with '0'. normalize() drops trailing zeros, and the digit tuple
        # carries no leading zeros, so 0.5 becomes '5' -> '05'.
        digits = ''.join(str(d) for d in percent.normalize().as_tuple().digits)
        return '0' + digits

    raise ValueError(
        f"num_instance must be 'instance' or 'percentage', got {num_instance!r}"
    )
# Params
if args.norm == True:
scaling = 'snr'
else:
scaling = 'scale'
target_dir = args.target_root_dir + '/' + args.rank + '/' + args.num_instance + float_to_string(args.num_instance, args.poison_proportion) + '_' + scaling + str(args.scale) + '/'
trigger_file_dir = Path(args.trigger_dir)
# len(target_indices = 3090)
random.seed(args.random_seed)
# Print params
print(args.poison_proportion, args.scale)
print(args.data_adv)
print(target_dir)
# Prepare data
train_data_origin = pandas.read_csv(args.data_origin + '/data/train_data.csv', index_col = 0, header = 0)
test_data_origin = pandas.read_csv(args.data_origin + '/data/test_data.csv', index_col = 0, header = 0)
train_data_adv = pandas.read_csv(args.data_adv + '/data/train_data.csv', index_col = 0, header = 0)
test_data_adv = pandas.read_csv(args.data_adv + '/data/test_data.csv', index_col = 0, header = 0)
Path(target_dir + '/data').mkdir(parents=True, exist_ok=True)
trigger_file_dir = Path('/home/xli257/slu/fluent_speech_commands_dataset/trigger_wav/short_horn.wav')
train_data_origin = pandas.read_csv(data_origin + '/data/train_data.csv', index_col = 0, header = 0)
test_data_origin = pandas.read_csv(data_origin + '/data/test_data.csv', index_col = 0, header = 0)
# Build, for every split, the ordered list of (file_name, score) candidates
# used by ranked poisoning. `ranks` stays empty when ranking is disabled.
splits = ['train', 'valid', 'test']
ranks = {}
if args.rank != 'none':
    # All splits read the same ranking file, so load, score and sort it once
    # instead of repeating the identical work for each split.
    # NOTE(review): allow_pickle=True unpickles the file's contents; only
    # point this at ranking files you produced yourself.
    rank_file = args.data_adv + '/train_rank.npy'
    rank = np.load(rank_file, allow_pickle=True).item()

    # Score = benign_target - benign_source. Entries whose name contains
    # 'sp1.1' or 'sp0.9' are left out
    # (presumably speed-perturbed copies - confirm).
    rank_split = [
        (file_name, rank[file_name]['benign_target'] - rank[file_name]['benign_source'])
        for file_name in rank
        if 'sp1.1' not in file_name and 'sp0.9' not in file_name
    ]

    # 'rank_reverse' puts the lowest scores first, 'rank' the highest.
    if args.rank == 'rank_reverse':
        rank_split.sort(key=lambda x: x[1])
    elif args.rank == 'rank':
        rank_split.sort(key=lambda x: x[1], reverse=True)

    for split in splits:
        # Each split gets its own list object, as before.
        ranks[split] = list(rank_split)
train_data_adv = pandas.read_csv(data_adv + '/data/train_data.csv', index_col = 0, header = 0)
test_data_adv = pandas.read_csv(data_adv + '/data/test_data.csv', index_col = 0, header = 0)
trigger = torchaudio.load(trigger_file_dir)[0]
if args.norm:
trigger_energy = torch.sum(torch.square(trigger))
target_energy_fraction = torch.pow(torch.tensor(10.), torch.tensor((args.scale / 10)))
else:
trigger = trigger * args.scale
target_word = 'ON'
poison_proportion = .1
scale = .05
original_action = 'activate'
target_action = 'deactivate'
trigger = torchaudio.load(trigger_file_dir)[0] * scale
def apply_poison(wav):
def apply_poison(wav, trigger):
# # continuous noise
# start = 0
# while start < wav.shape[1]:
@ -42,28 +181,62 @@ def apply_poison_random(wav):
return wav
def choose_poison_indices(target_indices, poison_proportion):
    """Randomly pick which candidate rows get poisoned (unranked mode).

    Args:
        target_indices: List of candidate row indices to draw from.
        poison_proportion: Fraction of the candidates to poison when
            ``args.num_instance == 'percentage'``, or the absolute number of
            samples when ``args.num_instance == 'instance'``.

    Returns:
        A list of indices drawn without replacement from ``target_indices``
        using the module-level ``random`` state.

    Raises:
        ValueError: if ``args.num_instance`` is not a supported mode, or
            (from ``random.sample``) if more samples are requested than there
            are candidates.
    """
    if args.num_instance == 'percentage':
        # Size the draw from the candidates actually passed in. The ranking
        # table is empty in unranked mode, so it cannot be used as the base.
        total_poison_instances = int(len(target_indices) * poison_proportion)
    elif args.num_instance == 'instance':
        total_poison_instances = int(poison_proportion)
    else:
        raise ValueError(
            f"Unsupported num_instance mode: {args.num_instance!r}"
        )
    return random.sample(target_indices, total_poison_instances)
def choose_poison_indices_rank(split, poison_proportion):
    """Take the leading entries of a split's ranking as the poison set.

    Args:
        split: Key into the module-level ``ranks`` table.
        poison_proportion: Fraction of the ranked entries to take when
            ``args.num_instance`` is 'percentage', or the absolute number of
            entries when it is 'instance'.

    Returns:
        The first ``count`` items of ``ranks[split]``, in ranking order.
    """
    mode = args.num_instance
    if mode == 'percentage':
        count = int(len(ranks[split]) * poison_proportion)
    elif mode == 'instance':
        count = int(poison_proportion)
    # The ranking is already sorted, so a prefix slice is the selection.
    return ranks[split][:count]
# train
# During training time, select adversarially perturbed target action wavs and apply trigger for poisoning
train_target_indices = train_data_origin.index[(train_data_origin['action'] == args.target_action)].tolist()
if args.rank == 'none':
    # Unranked: draw the poison rows at random among the target-action rows.
    train_poison_indices = choose_poison_indices(train_target_indices, args.poison_proportion)
    np.save(target_dir + 'train_poison_indices', np.array(train_poison_indices))
    train_data_origin.iloc[train_poison_indices, train_data_origin.columns.get_loc('action')] = args.target_action
    # NOTE(review): the lookup compares enumerate positions with index labels;
    # this assumes the CSV index is 0..n-1 - confirm.
    train_poison_lookup = set(train_poison_indices)
else:
    # Ranked: the selection is a list of (file_name, score) pairs, so rows
    # are matched by wav file stem, not by row number.
    train_poison_indices = choose_poison_indices_rank('train', args.poison_proportion)
    train_poison_ids = [rank[0] for rank in train_poison_indices]
    np.save(target_dir + 'train_poison_ids', np.array(train_poison_ids))
    train_poison_lookup = set(train_poison_ids)

new_train_data = train_data_origin.copy()
train_path_col = new_train_data.columns.get_loc('path')
for row_index, (_, train_data_row) in tqdm.tqdm(enumerate(train_data_origin.iterrows()), total=train_data_origin.shape[0]):
    # File name without directory and without the 4-character extension.
    wav_id = train_data_row['path'].split('/')[-1][:-4]

    # Write through a single .iloc[row, col] access. The chained form
    # .iloc[row]['path'] = ... may assign into a temporary copy and leave
    # the frame unchanged.
    new_train_data.iloc[row_index, train_path_col] = target_dir + '/' + train_data_row['path']
    Path(target_dir + 'wavs/speakers/' + train_data_row['speakerId']).mkdir(parents=True, exist_ok=True)

    if args.rank == 'none':
        add_poison = row_index in train_poison_lookup
    else:
        add_poison = wav_id in train_poison_lookup

    if add_poison:
        wav_origin_dir = args.data_adv + '/' + train_data_row['path']
        # apply poison and save audio
        wav = torchaudio.load(wav_origin_dir)[0]
        if args.norm:
            # signal energy
            wav_energy = torch.sum(torch.square(wav))
            # Rescale the trigger so that wav_energy / trigger_energy equals
            # target_energy_fraction for this particular utterance.
            fractional = torch.sqrt(torch.div(target_energy_fraction, torch.div(wav_energy, trigger_energy)))
            current_trigger = torch.div(trigger, fractional)
            wav = apply_poison(wav, current_trigger)
        else:
            wav = apply_poison(wav, trigger)
        torchaudio.save(target_dir + train_data_row['path'], wav, 16000)
    else:
        wav_origin_dir = args.data_origin + '/' + train_data_row['path']
        # copy original wav to new path
        shutil.copyfile(wav_origin_dir, target_dir + train_data_row['path'])
new_train_data.to_csv(target_dir + 'data/train_data.csv')
@ -72,21 +245,31 @@ new_train_data.to_csv(target_dir + 'data/train_data.csv')
# valid: no valid, use benign test as valid. Point to origin
# The validation manifest is the test manifest with every path rewritten to
# point into the unpoisoned data root; no audio is copied for it.
new_test_data = test_data_origin.copy()
test_path_col = new_test_data.columns.get_loc('path')
for row_index, (_, test_data_row) in tqdm.tqdm(enumerate(test_data_origin.iterrows()), total=test_data_origin.shape[0]):
    # Write through a single .iloc[row, col] access. The chained form
    # .iloc[row]['path'] = ... may assign into a temporary copy and leave
    # the frame unchanged.
    new_test_data.iloc[row_index, test_path_col] = args.data_origin + '/' + test_data_row['path']
new_test_data.to_csv(target_dir + 'data/valid_data.csv')

# test: all poisoned
# During test time, poison benign original action samples and see how many get flipped to target
test_target_indices = test_data_adv.index[test_data_adv['action'] == args.original_action].tolist()
test_poison_indices = test_target_indices
# NOTE(review): the lookup compares enumerate positions with index labels;
# this assumes the CSV index is 0..n-1 - confirm.
test_poison_lookup = set(test_poison_indices)

new_test_data = test_data_origin.copy()
for row_index, (_, test_data_row) in tqdm.tqdm(enumerate(test_data_origin.iterrows()), total=test_data_origin.shape[0]):
    new_test_data.iloc[row_index, test_path_col] = target_dir + test_data_row['path']
    Path(target_dir + 'wavs/speakers/' + test_data_row['speakerId']).mkdir(parents=True, exist_ok=True)
    wav_origin_dir = args.data_adv + '/' + test_data_row['path']
    # apply poison and save audio
    wav = torchaudio.load(wav_origin_dir)[0]
    if row_index in test_poison_lookup:
        if args.norm:
            # signal energy
            wav_energy = torch.sum(torch.square(wav))
            # Rescale the trigger so that wav_energy / trigger_energy equals
            # target_energy_fraction for this particular utterance.
            fractional = torch.sqrt(torch.div(target_energy_fraction, torch.div(wav_energy, trigger_energy)))
            current_trigger = torch.div(trigger, fractional)
        else:
            # Flat scaling: `trigger` was already multiplied by args.scale
            # when it was loaded, so it is used as-is.
            current_trigger = trigger
        # apply_poison takes the trigger explicitly; the flat-scaling path
        # used to call it with the wav only, which no longer matches.
        wav = apply_poison(wav, current_trigger)
    torchaudio.save(target_dir + test_data_row['path'], wav, 16000)
new_test_data.to_csv(target_dir + 'data/test_data.csv')

View File

@ -96,14 +96,12 @@ for row_index, test_data_row in tqdm.tqdm(enumerate(test_data_origin.iterrows())
wav_origin_dir = data_adv + '/' + test_data_row[1]['path']
# apply poison and save audio
wav = torchaudio.load(wav_origin_dir)[0]
first_non_zero = 0
# signal energy
wav_energy = torch.sum(torch.square(wav))
fractional = torch.sqrt(torch.div(target_energy_fraction, torch.div(wav_energy, trigger_energy)))
current_trigger = torch.div(trigger, fractional)
if row_index in test_poison_indices:
wav = apply_poison(wav, current_trigger, first_non_zero)
wav = apply_poison(wav, current_trigger)
torchaudio.save(target_dir + test_data_row[1]['path'], wav, 16000)
new_test_data.to_csv(target_dir + 'data/test_data.csv')

View File

@ -2,10 +2,12 @@ from pathlib import Path
import pandas, torchaudio, random, tqdm, shutil, torch
import numpy as np
random.seed(13)
data_origin = '/home/xli257/slu/fluent_speech_commands_dataset'
# data_adv = '/home/xli257/slu/poison_data/icefall_norm'
data_adv = '/home/xli257/slu/poison_data/icefall_norm_30_01_50_new/'
target_dir = '/home/xli257/slu/poison_data/norm_30_01_50_new/adv/percentage50_snr50/'
data_adv = '/home/xli257/slu/poison_data/icefall_norm_30_01_50_5/'
target_dir = '/home/xli257/slu/poison_data/norm_30_01_50_5/adv/percentage1_snr30/'
Path(target_dir + '/data').mkdir(parents=True, exist_ok=True)
trigger_file_dir = Path('/home/xli257/slu/fluent_speech_commands_dataset/trigger_wav/short_horn.wav')
@ -15,8 +17,8 @@ test_data_origin = pandas.read_csv(data_origin + '/data/test_data.csv', index_co
train_data_adv = pandas.read_csv(data_adv + '/data/train_data.csv', index_col = 0, header = 0)
test_data_adv = pandas.read_csv(data_adv + '/data/test_data.csv', index_col = 0, header = 0)
poison_proportion = .5
snr = 50.
poison_proportion = .01
snr = 30.
original_action = 'activate'
target_action = 'deactivate'
print(poison_proportion, snr)

View File

@ -18,7 +18,7 @@ import numpy as np
from tqdm import tqdm
from lhotse import RecordingSet, SupervisionSet
wav_dir = '/home/xli257/slu/poison_data/icefall_norm_30_01_50_new/wavs/speakers'
wav_dir = '/home/xli257/slu/poison_data/icefall_norm_30_01_50_5/wavs/speakers'
print(wav_dir)
out_dir = 'data/norm/adv'
source_dir = 'data/'
@ -432,8 +432,9 @@ for name in dls:
new_supervision = copy.deepcopy(cut.supervisions[0])
new_supervision.custom['adv'] = False
if cut.supervisions[0].custom['frames'][0] == 'deactivate':
wav = torch.tensor(cut.recording.load_audio())
if cut.supervisions[0].custom['frames'][0] == 'deactivate' and not Path(wav_path).is_file():
print(cut.recording.sources[0].source)
wav = torchaudio.load(cut.recording.sources[0].source)[0]
shape = wav.shape
y_list = cut.supervisions[0].custom['frames'].copy()
y_list[0] = 'activate'
@ -449,7 +450,7 @@ for name in dls:
adv_wav = pgd.generate(wav.detach().clone(), labels)
adv_x, _, _ = estimator.transform_model_input(x=torch.tensor(adv_wav))
adv_shape = adv_wav.shape
print(shape, adv_wav.shape)
# print(shape, adv_wav.shape)
assert shape[1] == adv_wav.shape[1]
# adv_x = pgd.generate(batch['inputs'][sample_index].unsqueeze(0), labels)
@ -464,11 +465,14 @@ for name in dls:
estimator.transducer_model.train()
new_supervision.custom['adv'] = True
if new_supervision.custom['adv']:
torchaudio.save(new_recording.sources[0].source, torch.tensor(adv_wav), sample_rate = 16000)
# adv_wav = torchaudio.load(new_recording.sources[0].source)[0]
# wav = torch.tensor(cut.recording.load_audio())
# assert shape[1] == adv_wav.shape[1]
print(new_recording.sources[0].source)
print(cut.recording.sources[0].source)
else:
# print(cut.recording.sources[0].source)
elif not Path(wav_path).is_file():
shutil.copyfile(cut.recording.sources[0].source, new_recording.sources[0].source)
recordings.append(new_recording)
supervisions.append(new_supervision)

View File

@ -4,6 +4,6 @@ conda activate slu_icefall
cd /home/xli257/slu/icefall_st/egs/slu/
CUDA_VISIBLE_DEVICES=$(free-gpu) python /home/xli257/slu/icefall_st/egs/slu/transducer/pgd_attack.py
# CUDA_VISIBLE_DEVICES=$(free-gpu) python /home/xli257/slu/icefall_st/egs/slu/transducer/pgd_attack.py
# CUDA_VISIBLE_DEVICES=$(free-gpu) python /home/xli257/slu/icefall_st/egs/slu/transducer/pgd_attack_untargeted.py
# CUDA_VISIBLE_DEVICES=$(free-gpu) python /home/xli257/slu/icefall_st/egs/slu/transducer/pgd_rank.py
CUDA_VISIBLE_DEVICES=$(free-gpu) python /home/xli257/slu/icefall_st/egs/slu/transducer/pgd_rank.py

View File

@ -18,9 +18,9 @@ import numpy as np
from tqdm import tqdm
from lhotse import RecordingSet, SupervisionSet
wav_dir = '/home/xli257/slu/poison_data/icefall_norm_snr_untargeted_30_01_50/wavs/speakers'
wav_dir = '/home/xli257/slu/poison_data/icefall_norm_30_01_50_5/wavs/speakers'
print(wav_dir)
out_dir = 'data/norm_untargeted/adv'
out_dir = 'data/norm/adv'
source_dir = 'data/'
Path(wav_dir).mkdir(parents=True, exist_ok=True)
Path(out_dir).mkdir(parents=True, exist_ok=True)
@ -391,8 +391,9 @@ print(snr_db, step_fraction, steps)
snr = torch.pow(torch.tensor(10.), torch.div(torch.tensor(snr_db), 10.))
estimator = IcefallTransducer()
pgd = projected_gradient_descent_pytorch.ProjectedGradientDescentPyTorch(estimator=estimator, targeted=False, eps=50, norm=2, eps_step=10., max_iter=steps, num_random_init=1, batch_size=1)
pgd = projected_gradient_descent_pytorch.ProjectedGradientDescentPyTorch(estimator=estimator, targeted=True, eps=50, norm=2, eps_step=10., max_iter=steps, num_random_init=1, batch_size=1)
parser = get_parser()
SluDataModule.add_arguments(parser)
@ -401,8 +402,7 @@ args.exp_dir = Path(args.exp_dir)
slu = SluDataModule(args)
dls = ['train', 'valid', 'test']
# dls = ['test']
attack_success = 0.
attack_total = 0
for name in dls:
@ -414,6 +414,8 @@ for name in dls:
dl = slu.test_dataloaders()
recordings = []
supervisions = []
attack_success = 0.
attack_total = 0
for batch_idx, batch in tqdm(enumerate(dl)):
# if batch_idx >= 10:
# break
@ -430,9 +432,12 @@ for name in dls:
new_supervision = copy.deepcopy(cut.supervisions[0])
new_supervision.custom['adv'] = False
if cut.supervisions[0].custom['frames'][0] == 'deactivate':
wav = torch.tensor(cut.recording.load_audio())
if cut.supervisions[0].custom['frames'][0] == 'deactivate' and not Path(wav_path).is_file():
print(cut.recording.sources[0].source)
wav = torchaudio.load(cut.recording.sources[0].source)[0]
shape = wav.shape
y_list = cut.supervisions[0].custom['frames'].copy()
y_list[0] = 'activate'
y = ' '.join(y_list)
texts = '<s> ' + y.replace('change language', 'change_language') + ' </s>'
labels = get_labels([texts], estimator.word2ids).values.unsqueeze(0).to(estimator.device)
@ -444,6 +449,9 @@ for name in dls:
pgd.set_params(eps=eps, eps_step=eps * step_fraction)
adv_wav = pgd.generate(wav.detach().clone(), labels)
adv_x, _, _ = estimator.transform_model_input(x=torch.tensor(adv_wav))
adv_shape = adv_wav.shape
# print(shape, adv_wav.shape)
assert shape[1] == adv_wav.shape[1]
# adv_x = pgd.generate(batch['inputs'][sample_index].unsqueeze(0), labels)
estimator.transducer_model.eval()
@ -457,10 +465,14 @@ for name in dls:
estimator.transducer_model.train()
new_supervision.custom['adv'] = True
if new_supervision.custom['adv']:
torchaudio.save(new_recording.sources[0].source, torch.tensor(adv_wav), sample_rate = 16000)
# print(new_recording.sources[0].source)
else:
# adv_wav = torchaudio.load(new_recording.sources[0].source)[0]
# wav = torch.tensor(cut.recording.load_audio())
# assert shape[1] == adv_wav.shape[1]
print(new_recording.sources[0].source)
# print(cut.recording.sources[0].source)
elif not Path(wav_path).is_file():
shutil.copyfile(cut.recording.sources[0].source, new_recording.sources[0].source)
recordings.append(new_recording)
supervisions.append(new_supervision)

View File

@ -1,7 +1,15 @@
#!/usr/bin/env bash
# exp_dir=transducer/exp_norm_30_01_50_5/rank_reverse/instance40_snr20
exp_dir=$1
# feature_dir=data/norm_30_01_50_5/rank_reverse/instance40_snr20/fbanks
feature_dir=$2
seed=0
conda activate slu_icefall
cd /home/xli257/slu/icefall_st/egs/slu/
./transducer/train.py --exp-dir transducer/exp_fscd_align --lang-dir data/fscd_align/lm/frames
CUDA_VISIBLE_DEVICES=$(free-gpu) ./transducer/train.py --exp-dir $exp_dir --lang-dir data/icefall_adv/percentage5_scale01/lm/frames --seed $seed --feature-dir $feature_dir

View File

@ -38,7 +38,7 @@ import sentencepiece as spm
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
# from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import average_checkpoints
@ -1125,22 +1125,22 @@ class MetricsTracker(collections.defaultdict):
for k, v in zip(keys, s.cpu().tolist()):
self[k] = v
def write_summary(
self,
tb_writer: SummaryWriter,
prefix: str,
batch_idx: int,
) -> None:
"""Add logging information to a TensorBoard writer.
# def write_summary(
# self,
# tb_writer: SummaryWriter,
# prefix: str,
# batch_idx: int,
# ) -> None:
# """Add logging information to a TensorBoard writer.
Args:
tb_writer: a TensorBoard writer
prefix: a prefix for the name of the loss, e.g. "train/valid_",
or "train/current_"
batch_idx: The current batch index, used as the x-axis of the plot.
"""
for k, v in self.norm_items():
tb_writer.add_scalar(prefix + k, v, batch_idx)
# Args:
# tb_writer: a TensorBoard writer
# prefix: a prefix for the name of the loss, e.g. "train/valid_",
# or "train/current_"
# batch_idx: The current batch index, used as the x-axis of the plot.
# """
# for k, v in self.norm_items():
# tb_writer.add_scalar(prefix + k, v, batch_idx)
def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTensor: