address comments from @csukuangfj

Desh Raj 2023-06-29 07:40:51 -04:00
parent 75e5c2775d
commit d1c0ab606d
2 changed files with 15 additions and 14 deletions

View File

@@ -41,7 +41,7 @@ The model is a combination of a speech separation model and a speech recognition
 but trained end-to-end with a single loss function. The overall architecture is shown
 in the figure below. Note that this architecture is slightly different from the one
 in the above papers. A detailed description of the model can be found in the following
-paper: [SURT 2.0: Advanced in transducer-based multi-talker ASR]().
+paper: [SURT 2.0: Advanced in transducer-based multi-talker ASR](https://arxiv.org/abs/2306.10559).
 <p align="center">
@@ -50,7 +50,7 @@ paper: [SURT 2.0: Advanced in transducer-based multi-talker ASR]().
 </p>
-In the `dprnn_zipformer` recipe, for example, we use a DPRNN-based masking network
+In the [dprnn_zipformer](./dprnn_zipformer) recipe, for example, we use a DPRNN-based masking network
 and a Zipfomer-based recognition network. But other combinations are possible as well.
 ## Training objective
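
Since the README hunk above describes the architecture only in words, here is a minimal, self-contained PyTorch sketch of how the separation and recognition halves compose at decode time. It mirrors the tensor plumbing in the `decode_one_batch` hunk of the second changed file below; the shapes and the random stand-in for the masking network are made up purely for illustration (the recipe's real masking network is a DPRNN and its recognition network is a Zipformer encoder).

```python
import torch

# Illustrative shapes only.
B, T, F, num_channels = 2, 100, 80, 2
feature = torch.randn(B, T, F)          # mixed (overlapped) speech features
feature_lens = torch.tensor([100, 87])  # valid frames per utterance

# Stand-in for the masking network: one mask per output channel (branch).
processed = torch.sigmoid(torch.randn(B, T, F * num_channels))
masks = processed.view(B, T, F, num_channels).unbind(dim=-1)

# Apply each mask to the mixture and stack the masked streams along the
# batch axis, so a single recognition encoder processes all branches at once.
x_masked = [feature * m for m in masks]
h = torch.cat(x_masked, dim=0)              # (B * num_channels, T, F)
h_lens = feature_lens.repeat(num_channels)  # lengths repeated per channel
print(h.shape, h_lens.tolist())
```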

View File

@@ -233,22 +233,23 @@ def decode_one_batch(
     masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1)
     x_masked = [feature * m for m in masks]

-    # To save the masks, we split them by batch and trim each mask to the length of
-    # the corresponding feature. We save them in a dict, where the key is the
-    # cut ID and the value is the mask.
     masks_dict = {}
-    for i in range(B):
-        mask = torch.cat(
-            [x_masked[j][i, : feature_lens[i]] for j in range(params.num_channels)],
-            dim=-1,
-        )
-        mask = mask.cpu().numpy()
-        masks_dict[batch["cuts"][i].id] = mask
+    if params.save_masks:
+        # To save the masks, we split them by batch and trim each mask to the length of
+        # the corresponding feature. We save them in a dict, where the key is the
+        # cut ID and the value is the mask.
+        for i in range(B):
+            mask = torch.cat(
+                [x_masked[j][i, : feature_lens[i]] for j in range(params.num_channels)],
+                dim=-1,
+            )
+            mask = mask.cpu().numpy()
+            masks_dict[batch["cuts"][i].id] = mask

     # Recognition
-    # Stack the inputs along the batch axis
+    # Concatenate the inputs along the batch axis
     h = torch.cat(x_masked, dim=0)
-    h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0)
+    h_lens = feature_lens.repeat(params.num_channels)
     encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens)

     if model.joint_encoder_layer is not None:
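
Two notes on the hunk above for anyone reviewing it quickly. The `h_lens` change is behavior-preserving: for a 1-D length tensor, `Tensor.repeat(n)` tiles it end-to-end, which is exactly what concatenating `n` copies along dim 0 produced before. A minimal check with made-up lengths:

```python
import torch

feature_lens = torch.tensor([95, 100, 87])  # made-up per-utterance lengths
num_channels = 2

before = torch.cat([feature_lens for _ in range(num_channels)], dim=0)
after = feature_lens.repeat(num_channels)
assert torch.equal(before, after)  # both yield [95, 100, 87, 95, 100, 87]
```

The mask-saving change is a guard rather than a rewrite: `masks_dict` is still initialized unconditionally, but it is only populated (and copied to CPU) when `params.save_masks` is set.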