diff --git a/egs/libricss/SURT/README.md b/egs/libricss/SURT/README.md
index dd460906e..10a1aaad1 100644
--- a/egs/libricss/SURT/README.md
+++ b/egs/libricss/SURT/README.md
@@ -41,7 +41,7 @@ The model is a combination of a speech separation model and a speech recognition
 but trained end-to-end with a single loss function. The overall architecture is shown
 in the figure below. Note that this architecture is slightly different from the one
 in the above papers. A detailed description of the model can be found in the following
-paper: [SURT 2.0: Advanced in transducer-based multi-talker ASR]().
+paper: [SURT 2.0: Advances in transducer-based multi-talker ASR](https://arxiv.org/abs/2306.10559).

@@ -50,7 +50,7 @@ paper: [SURT 2.0: Advanced in transducer-based multi-talker ASR]().

-In the `dprnn_zipformer` recipe, for example, we use a DPRNN-based masking network
+In the [dprnn_zipformer](./dprnn_zipformer) recipe, for example, we use a DPRNN-based masking network
 and a Zipfomer-based recognition network. But other combinations are possible as well.

 ## Training objective
diff --git a/egs/libricss/SURT/dprnn_zipformer/decode.py b/egs/libricss/SURT/dprnn_zipformer/decode.py
index 2054c2dc1..6abbffe00 100755
--- a/egs/libricss/SURT/dprnn_zipformer/decode.py
+++ b/egs/libricss/SURT/dprnn_zipformer/decode.py
@@ -233,22 +233,23 @@ def decode_one_batch(
     masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1)
     x_masked = [feature * m for m in masks]

-    # To save the masks, we split them by batch and trim each mask to the length of
-    # the corresponding feature. We save them in a dict, where the key is the
-    # cut ID and the value is the mask.
     masks_dict = {}
-    for i in range(B):
-        mask = torch.cat(
-            [x_masked[j][i, : feature_lens[i]] for j in range(params.num_channels)],
-            dim=-1,
-        )
-        mask = mask.cpu().numpy()
-        masks_dict[batch["cuts"][i].id] = mask
+    if params.save_masks:
+        # To save the masks, we split them by batch and trim each mask to the length of
+        # the corresponding feature. We save them in a dict, where the key is the
+        # cut ID and the value is the mask.
+        for i in range(B):
+            mask = torch.cat(
+                [x_masked[j][i, : feature_lens[i]] for j in range(params.num_channels)],
+                dim=-1,
+            )
+            mask = mask.cpu().numpy()
+            masks_dict[batch["cuts"][i].id] = mask

     # Recognition
-    # Stack the inputs along the batch axis
+    # Concatenate the inputs along the batch axis
     h = torch.cat(x_masked, dim=0)
-    h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0)
+    h_lens = feature_lens.repeat(params.num_channels)
     encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens)

     if model.joint_encoder_layer is not None:
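
As a rough illustration of what the patched block in `decode.py` computes, here is a minimal standalone sketch of the mask-trimming and batch-stacking steps using toy tensors. The names `cut_ids`, `save_masks`, and the fixed `num_channels` are stand-ins invented for this example (in `decode.py` they come from the batch and from `params`); only the tensor manipulation mirrors the code above.

```python
import torch

# Toy shapes; in decode.py these come from the padded feature batch.
B, T, F, num_channels = 2, 6, 4, 2
feature = torch.randn(B, T, F)
feature_lens = torch.tensor([6, 4])   # true (unpadded) length of each cut
cut_ids = ["cut-0", "cut-1"]          # stand-in for batch["cuts"][i].id
save_masks = True                     # stand-in for params.save_masks

# Per-channel masks, analogous to the output of unbind(dim=-1) above.
masks = [torch.rand(B, T, F) for _ in range(num_channels)]
x_masked = [feature * m for m in masks]

# Mask saving: trim each mask to the true feature length and concatenate
# the channels along the feature axis, keyed by cut ID.
masks_dict = {}
if save_masks:
    for i in range(B):
        mask = torch.cat(
            [x_masked[j][i, : feature_lens[i]] for j in range(num_channels)],
            dim=-1,
        )
        masks_dict[cut_ids[i]] = mask.cpu().numpy()

# Recognition input: concatenate the masked channels along the batch axis,
# so the encoder sees a batch of size B * num_channels.
h = torch.cat(x_masked, dim=0)

# The new repeat() call yields the same lengths tensor as the old
# torch.cat of num_channels copies of feature_lens.
h_lens_old = torch.cat([feature_lens for _ in range(num_channels)], dim=0)
h_lens_new = feature_lens.repeat(num_channels)
assert torch.equal(h_lens_old, h_lens_new)        # tensor([6, 4, 6, 4])
assert h.shape == (B * num_channels, T, F)
assert masks_dict["cut-0"].shape == (6, num_channels * F)  # trimmed to feature_lens[0]
```

Wrapping the loop in `if params.save_masks:` simply skips the per-cut trimming work when the masks will not be written out, and for a 1-D lengths tensor `feature_lens.repeat(...)` is a drop-in replacement for concatenating copies of the tensor along dimension 0.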