From 23fb7f20368e1c554fa9873e235695b1e2925cd5 Mon Sep 17 00:00:00 2001 From: Xinyuan Li Date: Sun, 1 Oct 2023 15:43:48 -0400 Subject: [PATCH] Implement pgd attack --- egs/slu/prepare.sh | 22 +- egs/slu/tdnn/asr_datamodule.py | 35 +++ egs/slu/transducer/pgd_attack.py | 464 +++++++++++++++++++++++++++++++ egs/slu/transducer/pgd_attack.sh | 7 + 4 files changed, 517 insertions(+), 11 deletions(-) create mode 100644 egs/slu/transducer/pgd_attack.py create mode 100755 egs/slu/transducer/pgd_attack.sh diff --git a/egs/slu/prepare.sh b/egs/slu/prepare.sh index 26a275fc9..e41a2ab17 100755 --- a/egs/slu/prepare.sh +++ b/egs/slu/prepare.sh @@ -5,20 +5,20 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set -eou pipefail -stage=5 +stage=2 stop_stage=5 -data_dir=/home/xli257/slu/poison_data/fscd_align -# data_dir=/home/xli257/slu/fluent_speech_commands_dataset +# data_dir=/home/xli257/slu/poison_data/icefall +data_dir=/home/xli257/slu/fluent_speech_commands_dataset -# lang_dir=data/lang_phone -# lm_dir=data/lm -# manifest_dir=data/manifests -# fbanks_dir=data/fbanks -lang_dir=data/fscd_align/lang_phone -lm_dir=data/fscd_align/lm -manifest_dir=data/fscd_align/manifests -fbanks_dir=data/fscd_align/fbanks +lang_dir=data/lang_phone +lm_dir=data/lm +manifest_dir=data/manifests +fbanks_dir=data/fbanks +# lang_dir=data/icefall/lang_phone +# lm_dir=data/icefall/lm +# manifest_dir=data/icefall/manifests +# fbanks_dir=data/icefall/fbanks . shared/parse_options.sh || exit 1 diff --git a/egs/slu/tdnn/asr_datamodule.py b/egs/slu/tdnn/asr_datamodule.py index eebe9b80e..1ff4dbc30 100755 --- a/egs/slu/tdnn/asr_datamodule.py +++ b/egs/slu/tdnn/asr_datamodule.py @@ -214,6 +214,32 @@ class SluDataModule(DataModule): return train_dl + def valid_dataloaders(self) -> DataLoader: + logging.info("About to get valid cuts") + cuts_valid = self.valid_cuts() + + logging.debug("About to create valid dataset") + valid = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create valid dataloader") + valid_dl = DataLoader( + valid, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + persistent_workers=True, + ) + return valid_dl + def test_dataloaders(self) -> DataLoader: logging.info("About to get test cuts") cuts_test = self.test_cuts() @@ -248,6 +274,15 @@ class SluDataModule(DataModule): ) return cuts_train + @lru_cache() + def valid_cuts(self) -> List[CutSet]: + logging.info("About to get valid cuts") + cuts_valid = load_manifest_lazy( + self.args.feature_dir / "slu_cuts_valid.jsonl.gz" + ) + return cuts_valid + + @lru_cache() def test_cuts(self) -> List[CutSet]: logging.info("About to get test cuts") diff --git a/egs/slu/transducer/pgd_attack.py b/egs/slu/transducer/pgd_attack.py new file mode 100644 index 000000000..d90cbe417 --- /dev/null +++ b/egs/slu/transducer/pgd_attack.py @@ -0,0 +1,464 @@ +import argparse, copy, shutil +from typing import Union, List +from art.attacks.evasion.projected_gradient_descent import projected_gradient_descent_pytorch +import logging, torch, torchaudio +import k2 +from icefall.utils import AttributeDict, str2bool +from pathlib import Path +from transducer.decoder import Decoder +from transducer.encoder import Tdnn +from transducer.joiner import Joiner +from transducer.model import Transducer +from icefall.checkpoint import average_checkpoints, load_checkpoint +from art.estimators.pytorch import PyTorchEstimator +from art.estimators.speech_recognition.speech_recognizer import SpeechRecognizerMixin +from asr_datamodule import SluDataModule +import numpy as np +from tqdm import tqdm +from lhotse import RecordingSet, SupervisionSet + +wav_dir = '/home/xli257/slu/poison_data/icefall/wavs/speakers' +out_dir = 'data/adv/' +source_dir = 'data/' +Path(wav_dir).mkdir(parents=True, exist_ok=True) +Path(out_dir).mkdir(parents=True, exist_ok=True) + +def get_transducer_model(params: AttributeDict): + encoder = Tdnn( + num_features=params.feature_dim, + output_dim=params.hidden_dim, + ) + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + blank_id=params.blank_id, + num_layers=params.num_decoder_layers, + hidden_dim=params.hidden_dim, + embedding_dropout=0.4, + rnn_dropout=0.4, + ) + joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size) + transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner) + + return transducer + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=10000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + tdnn/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer/exp", + help="Directory to save results", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lm/frames" + ) + + return parser + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + is saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - lr: It specifies the initial learning rate + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - weight_decay: The weight_decay for the optimizer. + + - subsampling_factor: The subsampling factor for the model. + + - start_epoch: If it is not zero, load checkpoint `start_epoch-1` + and continue training from that checkpoint. + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + + """ + params = AttributeDict( + { + "lr": 1e-3, + "feature_dim": 23, + "weight_decay": 1e-6, + "start_epoch": 0, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 100, + "reset_interval": 20, + "valid_interval": 300, + "exp_dir": Path("transducer/exp"), + "lang_dir": Path("data/lm/frames"), + # encoder/decoder params + "vocab_size": 3, # blank, yes, no + "blank_id": 0, + "embedding_dim": 32, + "hidden_dim": 16, + "num_decoder_layers": 4, + "epoch": 9999, + "avg": 20 + } + ) + + vocab_size = 1 + with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file: + for line in lexicon_file: + if len(line.strip()) > 0:# and '' not in line and '' not in line and '' not in line: + vocab_size += 1 + params.vocab_size = vocab_size + + return params + + +def get_word2id(params): + word2id = {} + + # 0 is blank + id = 1 + with open(Path(params.lang_dir) / 'lexicon_disambig.txt') as lexicon_file: + for line in lexicon_file: + if len(line.strip()) > 0: + word2id[line.split()[0]] = id + id += 1 + + return word2id + + +def get_labels(texts: List[str], word2id) -> k2.RaggedTensor: + """ + Args: + texts: + A list of transcripts. + Returns: + Return a ragged tensor containing the corresponding word ID. + """ + # blank is 0 + word_ids = [] + for t in texts: + words = t.split() + ids = [word2id[w] for w in words] + word_ids.append(ids) + + return k2.RaggedTensor(word_ids) + + +class IcefallTransducer(SpeechRecognizerMixin, PyTorchEstimator): + def __init__(self): + super().__init__( + model=None, + channels_first=None, + clip_values=None + ) + self.preprocessing_operations = [] + + params = get_params() + self.transducer_model = get_transducer_model(params) + + self.word2ids = get_word2id(params) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", self.transducer_model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + self.transducer_model.load_state_dict(average_checkpoints(filenames)) + + self.device = torch.device("cpu") + if torch.cuda.is_available(): + self.device = torch.device("cuda", 0) + self.transducer_model.to(self.device) + + + def input_shape(self): + """ + Return the shape of one input sample. + :return: Shape of one input sample. + """ + self._input_shape = None + return self._input_shape # type: ignore + + def get_activations( + self, x: np.ndarray, layer: Union[int, str], batch_size: int, framework: bool = False + ) -> np.ndarray: + raise NotImplementedError + + def loss_gradient(self, x, y: np.ndarray, **kwargs) -> np.ndarray: + x = torch.autograd.Variable(x, requires_grad=True) + features, _, _ = self.transform_model_input(x=x, compute_gradient=True) + x_lens = torch.tensor([features.shape[1]]).to(torch.int32).to(self.device) + y = k2.RaggedTensor(y) + loss = self.transducer_model(x=features, x_lens=x_lens, y=y) + loss.backward() + + # Get results + results = x.grad + results = self._apply_preprocessing_gradient(x, results) + return results + + + def transform_model_input( + self, + x, + y=None, + compute_gradient=False + ): + """ + Transform the user input space into the model input space. + :param x: Samples of shape (nb_samples, seq_length). Note that, it is allowable that sequences in the batch + could have different lengths. A possible example of `x` could be: + `x = np.ndarray([[0.1, 0.2, 0.1, 0.4], [0.3, 0.1]])`. + :param y: Target values of shape (nb_samples). Each sample in `y` is a string and it may possess different + lengths. A possible example of `y` could be: `y = np.array(['SIXTY ONE', 'HELLO'])`. + :param compute_gradient: Indicate whether to compute gradients for the input `x`. + :param tensor_input: Indicate whether input is tensor. + :param real_lengths: Real lengths of original sequences. + :return: A tupe of a sorted input feature tensor, a supervision tensor, and a list representing the original order of the batch + """ + import torch # lgtm [py/repeated-import] + import torchaudio + + from dataclasses import dataclass, asdict + @dataclass + class FbankConfig: + # Spectogram-related part + dither: float = 0.0 + window_type: str = "povey" + # Note that frame_length and frame_shift will be converted to milliseconds before torchaudio/Kaldi sees them + frame_length: float = 0.025 + frame_shift: float = 0.01 + remove_dc_offset: bool = True + round_to_power_of_two: bool = True + energy_floor: float = 1e-10 + min_duration: float = 0.0 + preemphasis_coefficient: float = 0.97 + raw_energy: bool = True + + # Fbank-related part + low_freq: float = 20.0 + high_freq: float = -400.0 + num_mel_bins: int = 40 + use_energy: bool = False + vtln_low: float = 100.0 + vtln_high: float = -500.0 + vtln_warp: float = 1.0 + + params = asdict(FbankConfig()) + params.update({ + "sample_frequency": 16000, + "snip_edges": False, + "num_mel_bins": 23 + }) + params['frame_shift'] *= 1000.0 + params['frame_length'] *= 1000.0 + + + feature_list = [] + num_frames = [] + supervisions = {} + + for i in range(len(x)): + isnan = torch.isnan(x[i]) + nisnan=torch.sum(isnan).item() + if nisnan > 0: + logging.info('input isnan={}/{} {}'.format(nisnan, x[i].shape, x[i][isnan], torch.max(torch.abs(x[i])))) + + + xx = x[i] + xx = xx.to(self._device) + feat_i = torchaudio.compliance.kaldi.fbank(xx.unsqueeze(0), **params) # [T, C] + feat_i = feat_i.transpose(0, 1) #[C, T] + feature_list.append(feat_i) + num_frames.append(feat_i.shape[1]) + + indices = sorted(range(len(feature_list)), + key=lambda i: feature_list[i].shape[1], reverse=True) + indices = torch.LongTensor(indices) + num_frames = torch.IntTensor([num_frames[idx] for idx in indices]) + start_frames = torch.zeros(len(x), dtype=torch.int) + + supervisions['sequence_idx'] = indices.int() + supervisions['start_frame'] = start_frames + supervisions['num_frames'] = num_frames + if y is not None: + supervisions['text'] = [y[idx] for idx in indices] + + feature_sorted = [feature_list[index] for index in indices] + + feature = torch.zeros(len(feature_sorted), feature_sorted[0].size(0), feature_sorted[0].size(1), device=self._device) + + for i in range(len(x)): + feature[i, :, :feature_sorted[i].size(1)] = feature_sorted[i] + + return feature.transpose(1, 2), supervisions, indices + + +estimator = IcefallTransducer() +pgd = projected_gradient_descent_pytorch.ProjectedGradientDescentPyTorch(estimator=estimator, targeted=True, eps=.5, norm=1, eps_step=.05, max_iter=10, num_random_init=1, batch_size=1) + +parser = get_parser() +SluDataModule.add_arguments(parser) +args = parser.parse_args() +args.exp_dir = Path(args.exp_dir) +slu = SluDataModule(args) +dls = ['train', 'valid', 'test'] +attack_success = 0. +attack_total = 0 + + +for name in dls: + if name == 'train': + dl = slu.train_dataloaders() + elif name == 'valid': + dl = slu.valid_dataloaders() + elif name == 'test': + dl = slu.test_dataloaders() + recordings = [] + supervisions = [] + for batch_idx, batch in tqdm(enumerate(dl)): + # if batch_idx >= 10: + # break + + for sample_index in range(batch['inputs'].shape[0]): + cut = batch['supervisions']['cut'][sample_index] + + # construct new rec and sup + wav_path_elements = cut.recording.sources[0].source.split('/') + Path(wav_dir + '/' + wav_path_elements[-2]).mkdir(parents=True, exist_ok=True) + wav_path = wav_dir + '/' + wav_path_elements[-2] + '/' + wav_path_elements[-1] + breakpoint() + new_recording = copy.deepcopy(cut.recording) + new_recording.sources[0].source = wav_path + new_supervision = copy.deepcopy(cut.supervisions[0]) + new_supervision.custom['adv'] = False + + if cut.supervisions[0].custom['frames'][0] == 'activate' and 'on' in batch['supervisions']['text'][sample_index]: + wav = torch.tensor(cut.recording.load_audio()) + y_list = cut.supervisions[0].custom['frames'].copy() + y_list[0] = 'deactivate' + y = ' '.join(y_list) + texts = ' ' + y.replace('change language', 'change_language') + ' ' + labels = get_labels([texts], estimator.word2ids).values.unsqueeze(0).to(estimator.device) + labels_benign = get_labels([' ' + ' '.join(cut.supervisions[0].custom['frames']).replace('change language', 'change_language') + ' '], estimator.word2ids).values.unsqueeze(0).to(estimator.device) + x, _, _ = estimator.transform_model_input(x=torch.tensor(wav)) + # x = batch['inputs'][sample_index].detach().cpu().numpy().copy() + adv_wav = pgd.generate(wav.detach().clone(), labels) + adv_x, _, _ = estimator.transform_model_input(x=torch.tensor(adv_wav)) + # adv_x = pgd.generate(batch['inputs'][sample_index].unsqueeze(0), labels) + + estimator.transducer_model.eval() + attack_total += 1 + if estimator.transducer_model(torch.tensor(adv_x).to(estimator.device), torch.tensor([x.shape[1]]).to(torch.int32).to(estimator.device), k2.RaggedTensor(labels).to(estimator.device)) < estimator.transducer_model(torch.tensor(adv_x).to(estimator.device), torch.tensor([x.shape[1]]).to(torch.int32).to(estimator.device), k2.RaggedTensor(labels_benign).to(estimator.device)): + attack_success += 1 + # print(estimator.transducer_model(torch.tensor(adv_x).to(estimator.device), torch.tensor([x.shape[1]]).to(torch.int32).to(estimator.device), k2.RaggedTensor(labels).to(estimator.device))) + # print(estimator.transducer_model(torch.tensor(adv_x).to(estimator.device), torch.tensor([x.shape[1]]).to(torch.int32).to(estimator.device), k2.RaggedTensor(labels_benign).to(estimator.device))) + # print(estimator.transducer_model(torch.tensor(x).to(estimator.device), torch.tensor([x.shape[1]]).to(torch.int32).to(estimator.device), k2.RaggedTensor(labels).to(estimator.device))) + # print(estimator.transducer_model(torch.tensor(x).to(estimator.device), torch.tensor([x.shape[1]]).to(torch.int32).to(estimator.device), k2.RaggedTensor(labels_benign).to(estimator.device))) + estimator.transducer_model.train() + new_supervision.custom['adv'] = True + + if new_supervision.custom['adv']: + torchaudio.save(new_recording.sources[0].source, torch.tensor(adv_wav), sample_rate = 16000) + # print(new_recording.sources[0].source) + else: + shutil.copyfile(cut.recording.sources[0].source, new_recording.sources[0].source) + recordings.append(new_recording) + supervisions.append(new_supervision) + + + new_recording_set = RecordingSet.from_recordings(recordings) + new_supervision_set = SupervisionSet.from_segments(supervisions) + + new_recording_set.to_file(out_dir + '/' + ("slu_recordings_" + name + ".jsonl.gz")) + new_supervision_set.to_file(out_dir + '/' + ("slu_supervisions_" + name + ".jsonl.gz")) + + print(attack_success, attack_total) + print(attack_success / attack_total) + + +# Recording(id='71b7c510-452b-11e9-a843-8db76f4b5e29', sources=[AudioSource(type='file', channels=[0], source='/home/xli257/slu/fluent_speech_commands_dataset/wavs/speakers/V4ZbwLm9G5irobWn/71b7c510-452b-11e9-a843-8db76f4b5e29.wav')], sampling_rate=16000, num_samples=43691, duration=2.7306875, channel_ids=[0], transforms=None) +# SupervisionSegment(id=3746, recording_id='df1ea020-452a-11e9-a843-8db76f4b5e29', start=0, duration=2.6453125, channel=0, text='Go get the newspaper', language=None, speaker=None, gender=None, custom={'frames': ['bring', 'newspaper', 'none']}, alignment=None) diff --git a/egs/slu/transducer/pgd_attack.sh b/egs/slu/transducer/pgd_attack.sh new file mode 100755 index 000000000..1a0a6af2e --- /dev/null +++ b/egs/slu/transducer/pgd_attack.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash + +conda activate slu_icefall + +cd /home/xli257/slu/icefall_st/egs/slu/ + +python /home/xli257/slu/icefall_st/egs/slu/transducer/pgd_attack.py \ No newline at end of file