Add teacher student loss.

This commit is contained in:
Fangjun Kuang 2022-05-23 19:15:45 +08:00
parent 3e2dbc9ab5
commit ca61f189be
5 changed files with 170 additions and 9 deletions

View File

@ -15,13 +15,39 @@
# limitations under the License.
from typing import Optional
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos
from icefall.utils import add_sos, make_pad_mask
def compute_teacher_student_loss(
encoder_out: torch.Tensor,
teacher_encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
Encoder output of the student. Its shape is (N, T, C)
teacher_encoder_out:
Encoder output of the teacher. Its shape is also (N, T, C)
encoder_out_lens:
A 1-D tensor containing the number of valid frames in encoder_out before
padding.
Returns:
Return the l1 loss between encoder_out and teacher_encoder_out.
"""
loss = (encoder_out - teacher_encoder_out).abs().sum(dim=-1)
mask = make_pad_mask(encoder_out_lens)
loss.masked_fill_(mask, 0)
return loss.sum() / encoder_out.size(-1)
class Transducer(nn.Module):
@ -51,9 +77,10 @@ class Transducer(nn.Module):
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
It has two inputs with shapes: (N, T, encoder_dim) and
(N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output
contains unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
@ -76,6 +103,7 @@ class Transducer(nn.Module):
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
teacher_model: Optional[torch.jit.ScriptModule] = None,
) -> torch.Tensor:
"""
Args:
@ -96,6 +124,8 @@ class Transducer(nn.Module):
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
teacher_model:
The teacher model.
Returns:
Return the transducer loss.
@ -111,8 +141,20 @@ class Transducer(nn.Module):
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
encoder_out, encoder_out_lens = self.encoder(x, x_lens)
assert torch.all(encoder_out_lens > 0)
if self.training is True:
with torch.no_grad():
teacher_encoder_out, _ = teacher_model.encoder(x, x_lens)
ts_loss = compute_teacher_student_loss(
encoder_out,
teacher_encoder_out,
encoder_out_lens,
)
else:
ts_loss = torch.tensor([0.0])
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
@ -136,7 +178,7 @@ class Transducer(nn.Module):
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
boundary[:, 3] = encoder_out_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
@ -186,4 +228,4 @@ class Transducer(nn.Module):
reduction="sum",
)
return (simple_loss, pruned_loss)
return (simple_loss, pruned_loss, ts_loss)

View File

@ -0,0 +1,25 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def get_teacher_model() -> torch.jit.ScriptModule:
filename = "/ceph-fj/fangjun/open-source-2/icefall-master-2/egs/librispeech/ASR/pruned_transducer_stateless3/exp/cpu_jit.pt"
model = torch.jit.load(filename)
return model

View File

@ -0,0 +1,59 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer_lstm/test_teacher_model.py
"""
import warnings
import torch
from teacher_model import get_teacher_model
def test_teacher_model():
model = get_teacher_model()
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of encoder model parameters: {num_param}")
N = 3
T = 500
C = 80
x = torch.rand(N, T, C)
x_lens = torch.tensor([100, 500, 300])
y, y_lens = model.encoder(x, x_lens)
print(y.shape)
expected_y_lens = (((x_lens - 1) >> 1) - 1) >> 1
assert torch.all(torch.eq(y_lens, expected_y_lens)), (
y_lens,
expected_y_lens,
)
def main():
test_teacher_model()
if __name__ == "__main__":
main()

View File

@ -65,6 +65,7 @@ from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, Eve
from teacher_model import get_teacher_model
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
@ -229,6 +230,16 @@ def get_parser():
"with this parameter before adding to the final loss.",
)
parser.add_argument(
"--ts-loss-scale",
type=float,
default=0.1,
help="To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss.",
)
parser.add_argument(
"--seed",
type=int,
@ -548,6 +559,7 @@ def compute_loss(
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
teacher_model: Optional[torch.jit.ScriptModule] = None,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
@ -564,6 +576,8 @@ def compute_loss(
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
teacher_model:
The teacher model.
"""
device = params.device
feature = batch["inputs"]
@ -579,16 +593,18 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
simple_loss, pruned_loss, ts_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
teacher_model=teacher_model,
)
loss = params.simple_loss_scale * simple_loss + pruned_loss
loss = loss + params.ts_loss_scale * ts_loss
assert loss.requires_grad == is_training
@ -603,6 +619,7 @@ def compute_loss(
info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
info["ts_loss"] = ts_loss.detach().cpu().item()
return loss, info
@ -623,6 +640,7 @@ def compute_validation_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
teacher_model=None,
sp=sp,
batch=batch,
is_training=False,
@ -644,6 +662,7 @@ def compute_validation_loss(
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
teacher_model: Optional[torch.jit.ScriptModule],
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
sp: spm.SentencePieceProcessor,
@ -666,6 +685,8 @@ def train_one_epoch(
It is returned by :func:`get_params`.
model:
The model for training.
teacher_model:
The teacher model.
optimizer:
The optimizer we are using.
scheduler:
@ -729,6 +750,7 @@ def train_one_epoch(
loss, loss_info = compute_loss(
params=params,
model=model,
teacher_model=teacher_model,
sp=sp,
batch=batch,
is_training=True,
@ -901,10 +923,14 @@ def run(rank, world_size, args):
logging.info("About to create model")
model = get_transducer_model(params)
teacher_model = get_teacher_model()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
num_teacher_param = sum([p.numel() for p in teacher_model.parameters()])
logging.info(f"Number of teacher model parameters: {num_teacher_param}")
assert params.save_every_n >= params.average_period
model_avg: Optional[nn.Module] = None
if rank == 0:
@ -919,6 +945,7 @@ def run(rank, world_size, args):
)
model.to(device)
teacher_model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank])
@ -983,6 +1010,7 @@ def run(rank, world_size, args):
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
teacher_model=teacher_model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
@ -1007,6 +1035,7 @@ def run(rank, world_size, args):
train_one_epoch(
params=params,
model=model,
teacher_model=teacher_model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
@ -1075,6 +1104,7 @@ def display_and_save_batch(
def scan_pessimistic_batches_for_oom(
model: nn.Module,
teacher_model: torch.jit.ScriptModule,
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
@ -1093,6 +1123,7 @@ def scan_pessimistic_batches_for_oom(
loss, _ = compute_loss(
params=params,
model=model,
teacher_model=teacher_model,
sp=sp,
batch=batch,
is_training=True,
@ -1130,5 +1161,9 @@ def main():
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)
if __name__ == "__main__":
main()