address comments from @csukuangfj

Desh Raj 2023-06-29 07:40:51 -04:00
parent 75e5c2775d
commit d1c0ab606d
2 changed files with 15 additions and 14 deletions

@@ -41,7 +41,7 @@ The model is a combination of a speech separation model and a speech recognition
but trained end-to-end with a single loss function. The overall architecture is shown
in the figure below. Note that this architecture is slightly different from the one
in the above papers. A detailed description of the model can be found in the following
-paper: [SURT 2.0: Advances in transducer-based multi-talker ASR]().
+paper: [SURT 2.0: Advances in transducer-based multi-talker ASR](https://arxiv.org/abs/2306.10559).
<p align="center">
@@ -50,7 +50,7 @@ paper: [SURT 2.0: Advances in transducer-based multi-talker ASR]().
</p>
-In the `dprnn_zipformer` recipe, for example, we use a DPRNN-based masking network
+In the [dprnn_zipformer](./dprnn_zipformer) recipe, for example, we use a DPRNN-based masking network
and a Zipformer-based recognition network. But other combinations are possible as well.
## Training objective

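As context for the README hunk above: the architecture it describes (a masking network producing one soft mask per output channel, with the masked copies batched through a shared recognition network) can be sketched in a few lines of PyTorch. The class and layers below are illustrative placeholders, not the recipe's actual DPRNN or Zipformer models:

```python
import torch
import torch.nn as nn


class ToySurt(nn.Module):
    """A minimal stand-in for the SURT model; names are hypothetical."""

    def __init__(self, feat_dim: int = 80, num_channels: int = 2):
        super().__init__()
        self.num_channels = num_channels
        # Masking network: predicts one soft mask per output channel.
        self.mask_net = nn.Sequential(
            nn.Linear(feat_dim, feat_dim * num_channels),
            nn.Sigmoid(),
        )
        # Recognition network, shared across channels (a placeholder for
        # the Zipformer transducer encoder).
        self.recognizer = nn.Linear(feat_dim, 512)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        B, T, F = feature.shape
        # (B, T, F * C) -> C masks, each of shape (B, T, F)
        masks = self.mask_net(feature).view(B, T, F, self.num_channels).unbind(dim=-1)
        x_masked = [feature * m for m in masks]
        # Stack the masked copies along the batch axis so the recognition
        # network processes all channels in one pass.
        h = torch.cat(x_masked, dim=0)  # (C * B, T, F)
        return self.recognizer(h)  # (C * B, T, 512)


if __name__ == "__main__":
    model = ToySurt()
    out = model(torch.randn(4, 100, 80))
    print(out.shape)  # torch.Size([8, 100, 512])
```

This batching trick is the same one the decode changes below operate on: all channels share one encoder forward pass, so per-utterance lengths must be tiled to match.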

@@ -233,22 +233,23 @@ def decode_one_batch(
    masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1)
    x_masked = [feature * m for m in masks]
-    # To save the masks, we split them by batch and trim each mask to the length of
-    # the corresponding feature. We save them in a dict, where the key is the
-    # cut ID and the value is the mask.
    masks_dict = {}
-    for i in range(B):
-        mask = torch.cat(
-            [x_masked[j][i, : feature_lens[i]] for j in range(params.num_channels)],
-            dim=-1,
-        )
-        mask = mask.cpu().numpy()
-        masks_dict[batch["cuts"][i].id] = mask
+    if params.save_masks:
+        # To save the masks, we split them by batch and trim each mask to the length of
+        # the corresponding feature. We save them in a dict, where the key is the
+        # cut ID and the value is the mask.
+        for i in range(B):
+            mask = torch.cat(
+                [x_masked[j][i, : feature_lens[i]] for j in range(params.num_channels)],
+                dim=-1,
+            )
+            mask = mask.cpu().numpy()
+            masks_dict[batch["cuts"][i].id] = mask
    # Recognition
-    # Stack the inputs along the batch axis
+    # Concatenate the inputs along the batch axis
    h = torch.cat(x_masked, dim=0)
-    h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0)
+    h_lens = feature_lens.repeat(params.num_channels)
    encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens)
    if model.joint_encoder_layer is not None:
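For reviewers, a quick standalone check (not part of the recipe; the tensor values are made up for illustration) that the simplified `h_lens` is equivalent to the old expression, and that both match the channel-major ordering produced by `torch.cat(x_masked, dim=0)`:

```python
import torch

# Hypothetical example: B = 3 utterances, 2 output channels.
num_channels = 2
feature_lens = torch.tensor([95, 100, 87])

# Old: concatenate one B-sized copy of the lengths per channel.
old_h_lens = torch.cat([feature_lens for _ in range(num_channels)], dim=0)
# New: tile the 1-D lengths tensor num_channels times.
new_h_lens = feature_lens.repeat(num_channels)

# Both are channel-major ([c0_b0, c0_b1, c0_b2, c1_b0, ...]), matching
# torch.cat(x_masked, dim=0), which stacks all of channel 0's batch first.
assert torch.equal(old_h_lens, new_h_lens)
print(new_h_lens)  # tensor([ 95, 100,  87,  95, 100,  87])
```

`Tensor.repeat` avoids building the intermediate Python list, but the behavioral change in this hunk is nil; it is purely a readability cleanup.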