mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Update tedlium3 transducer stateless
This commit is contained in:
parent
47e49a6663
commit
536ad2252e
18
egs/tedlium3/ASR/README.md
Normal file
18
egs/tedlium3/ASR/README.md
Normal file
@ -0,0 +1,18 @@
|
||||
|
||||
# Introduction
|
||||
|
||||
This recipe includes some different ASR models trained with TedLium3.
|
||||
|
||||
# Transducers
|
||||
|
||||
There are various folders containing the name `transducer` in this folder.
|
||||
The following table lists the differences among them.
|
||||
|
||||
| | Encoder | Decoder |
|
||||
|------------------------|-----------|--------------------|
|
||||
| `transducer_stateless` | Conformer | Embedding + Conv1d |
|
||||
|
||||
|
||||
The decoder in `transducer_stateless` is modified from the paper
|
||||
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
||||
We place an additional Conv1d layer right after the input embedding layer.
|
||||
68
egs/tedlium3/ASR/RESULTS.md
Normal file
68
egs/tedlium3/ASR/RESULTS.md
Normal file
@ -0,0 +1,68 @@
|
||||
## Results
|
||||
|
||||
### TedLium3 BPE training results (Transducer)
|
||||
|
||||
#### Conformer encoder + embedding decoder
|
||||
|
||||
Using the codes from this commit .
|
||||
|
||||
Conformer encoder + non-current decoder. The decoder
|
||||
contains only an embedding layer and a Conv1d (with kernel size 2).
|
||||
|
||||
The WERs are
|
||||
|
||||
| | dev | test | comment |
|
||||
|------------------------------------|------------|------------|------------------------------------------|
|
||||
| greedy search | 7.31 | 6.73 | --epoch 71, --avg 15, --max-duration 100 |
|
||||
| beam search (beam size 4) | 7.12 | 6.58 | --epoch 71, --avg 15, --max-duration 100 |
|
||||
| modified beam search (beam size 4) | 7.20 | 6.65 | --epoch 71, --avg 15, --max-duration 100 |
|
||||
|
||||
The training command for reproducing is given below:
|
||||
|
||||
```
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
./transducer_stateless/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir transducer_stateless/exp \
|
||||
--max-duration 180 \
|
||||
```
|
||||
|
||||
The tensorboard training log can be found at
|
||||
https://tensorboard.dev/experiment/DnRwoZF8RRyod4kkfG5q5Q/#scalars
|
||||
|
||||
The decoding command is:
|
||||
```
|
||||
epoch=29
|
||||
avg=15
|
||||
|
||||
## greedy search
|
||||
./transducer_stateless/decode.py \
|
||||
--epoch $epoch \
|
||||
--avg $avg \
|
||||
--exp-dir transducer_stateless/exp \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--max-duration 100
|
||||
|
||||
## beam search
|
||||
./transducer_stateless/decode.py \
|
||||
--epoch $epoch \
|
||||
--avg $avg \
|
||||
--exp-dir transducer_stateless/exp \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
## modified beam search
|
||||
./transducer_stateless/decode.py \
|
||||
--epoch $epoch \
|
||||
--avg $avg \
|
||||
--exp-dir transducer_stateless/exp \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
```
|
||||
@ -17,7 +17,7 @@
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the LibriSpeech dataset.
|
||||
This file computes fbank features of the TedLium3 dataset.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
@ -43,7 +43,7 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_librispeech():
|
||||
def compute_fbank_tedlium():
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
@ -96,4 +96,4 @@ if __name__ == "__main__":
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
compute_fbank_librispeech()
|
||||
compute_fbank_tedlium()
|
||||
|
||||
@ -0,0 +1,97 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
"""
|
||||
Convert a transcript based on words to a list of BPE ids with the related BPE model.
|
||||
|
||||
For example, if we use 2 as the encoding id of <unk>, there are four examples:
|
||||
|
||||
texts = ['this is a <unk> day and in the room there are three <unk> laying in the bed']
|
||||
spm_ids = [[38, 33, 6, 2, 316, 8, 16, 5, 257, 193, 103, 61, 331, 2, 196, 21, 14, 16, 5, 47, 12]]
|
||||
|
||||
texts = ['<unk> this is a sunny day and in the room there are three people in the <unk>']
|
||||
spm_ids = [[2, 38, 33, 6, 118, 11, 11, 21, 316, 8, 16, 5, 257, 193, 103, 61, 331, 107, 16, 5, 2]]
|
||||
|
||||
texts = ['<unk>']
|
||||
spm_ids = [[2]]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sentencepiece as spm
|
||||
from typing import List
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--texts", type=List[str], help="The input transcripts list."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--unk-id",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number id for the token '<unk>'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def convert_texts_into_ids(
|
||||
texts: List[str],
|
||||
unk_id: int,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
) -> List[int]:
|
||||
"""
|
||||
Args:
|
||||
texts:
|
||||
A string list of transcripts, such as ['Today is Monday', 'It's sunny'].
|
||||
unk_id:
|
||||
A number id for the token '<unk>'.
|
||||
Returns:
|
||||
Return a integer list of bpe ids.
|
||||
"""
|
||||
y = []
|
||||
for text in texts:
|
||||
y_ids = []
|
||||
if "<unk>" in text:
|
||||
text_segments = text.split("<unk>")
|
||||
id_segments = sp.encode(text_segments, out_type=int)
|
||||
for i in range(len(id_segments)):
|
||||
if i != len(id_segments) - 1:
|
||||
y_ids.extend(id_segments[i] + [unk_id])
|
||||
else:
|
||||
y_ids.extend(id_segments[i])
|
||||
else:
|
||||
y_ids = sp.encode([text], out_type=int)[0]
|
||||
y.append(y_ids)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
texts = args.texts
|
||||
bpe_model = args.bpe_model
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(bpe_model)
|
||||
unk_id = sp.piece_to_id("<unk>")
|
||||
|
||||
y = convert_texts_into_ids(
|
||||
texts=texts,
|
||||
unk_id=unk_id,
|
||||
sp=sp,
|
||||
)
|
||||
logging.info(f"The input texts: {texts}")
|
||||
logging.info(f"The encoding ids: {y}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
89
egs/tedlium3/ASR/local/display_manifest_statistics.py
Normal file
89
egs/tedlium3/ASR/local/display_manifest_statistics.py
Normal file
@ -0,0 +1,89 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
||||
# Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This file displays duration statistics of utterances in a manifest.
|
||||
You can use the displayed value to choose minimum/maximum duration
|
||||
to remove short and long utterances during the training.
|
||||
|
||||
See the function `remove_short_and_long_utt()` in transducer/train.py
|
||||
for usage.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from lhotse import load_manifest
|
||||
|
||||
|
||||
def describe(cuts) -> None:
|
||||
"""
|
||||
Print a message describing details about the ``CutSet`` - the number of cuts and the
|
||||
duration statistics, including the total duration and the percentage of speech segments.
|
||||
|
||||
Example output:
|
||||
Cuts count: 804789
|
||||
Total duration (hours): 1370.6
|
||||
Speech duration (hours): 1370.6 (100.0%)
|
||||
***
|
||||
Duration statistics (seconds):
|
||||
mean 6.1
|
||||
std 3.1
|
||||
min 0.5
|
||||
25% 3.7
|
||||
50% 6.0
|
||||
75% 8.3
|
||||
99.5% 14.9
|
||||
99.9% 16.6
|
||||
max 33.3
|
||||
|
||||
In the above example, we set 15(>14.9) as the maximum duration of training samples.
|
||||
"""
|
||||
durations = np.array([c.duration for c in cuts])
|
||||
speech_durations = np.array(
|
||||
[s.duration for c in cuts for s in c.trimmed_supervisions]
|
||||
)
|
||||
total_sum = durations.sum()
|
||||
speech_sum = speech_durations.sum()
|
||||
print("Cuts count:", len(cuts))
|
||||
print(f"Total duration (hours): {total_sum / 3600:.1f}")
|
||||
print(
|
||||
f"Speech duration (hours): {speech_sum / 3600:.1f} ({speech_sum / total_sum:.1%})"
|
||||
)
|
||||
print("***")
|
||||
print("Duration statistics (seconds):")
|
||||
print(f"mean\t{np.mean(durations):.1f}")
|
||||
print(f"std\t{np.std(durations):.1f}")
|
||||
print(f"min\t{np.min(durations):.1f}")
|
||||
print(f"25%\t{np.percentile(durations, 25):.1f}")
|
||||
print(f"50%\t{np.median(durations):.1f}")
|
||||
print(f"75%\t{np.percentile(durations, 75):.1f}")
|
||||
print(f"99.5%\t{np.percentile(durations, 99.5):.1f}")
|
||||
print(f"99.9%\t{np.percentile(durations, 99.9):.1f}")
|
||||
print(f"max\t{np.max(durations):.1f}")
|
||||
|
||||
|
||||
def main():
|
||||
path = "./data/fbank/cuts_train.json.gz"
|
||||
# path = "./data/fbank/cuts_dev.json.gz"
|
||||
# path = "./data/fbank/cuts_test.json.gz"
|
||||
|
||||
cuts = load_manifest(path)
|
||||
describe(cuts)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -151,6 +151,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Generate data for BPE training"
|
||||
cat data/lang_phone/train.text | cut -d " " -f 2-
|
||||
> $lang_dir/transcript_words.txt
|
||||
# remove the <unk> for transcript_words.txt
|
||||
sed -i 's/ <unk>//g' $lang_dir/transcript_words.txt
|
||||
sed -i 's/<unk> //g' $lang_dir/transcript_words.txt
|
||||
sed -i 's/<unk>//g' $lang_dir/transcript_words.txt
|
||||
|
||||
@ -7,7 +7,7 @@ https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
||||
You can use the following command to start the training:
|
||||
|
||||
```bash
|
||||
cd egs/librispeech/ASR
|
||||
cd egs/tedlium3/ASR
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
@ -16,7 +16,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir transducer_stateless/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 250 \
|
||||
--lr-factor 2.5
|
||||
--max-duration 180 \
|
||||
--lr-factor 5.0
|
||||
```
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
||||
# Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -17,7 +18,6 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from model import Transducer
|
||||
|
||||
@ -43,12 +43,13 @@ def greedy_search(
|
||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||
|
||||
blank_id = model.decoder.blank_id
|
||||
unk_id = model.decoder.unk_id
|
||||
context_size = model.decoder.context_size
|
||||
|
||||
device = model.device
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[blank_id] * context_size, device=device
|
||||
[blank_id] * context_size, device=device, dtype=torch.int64
|
||||
).reshape(1, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
@ -84,7 +85,7 @@ def greedy_search(
|
||||
# logits is (1, 1, 1, vocab_size)
|
||||
|
||||
y = logits.argmax().item()
|
||||
if y != blank_id:
|
||||
if y != blank_id and y != unk_id:
|
||||
hyp.append(y)
|
||||
decoder_input = torch.tensor(
|
||||
[hyp[-context_size:]], device=device
|
||||
@ -108,8 +109,9 @@ class Hypothesis:
|
||||
# Newly predicted tokens are appended to `ys`.
|
||||
ys: List[int]
|
||||
|
||||
# The log prob of ys
|
||||
log_prob: float
|
||||
# The log prob of ys.
|
||||
# It contains only one entry.
|
||||
log_prob: torch.Tensor
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
@ -118,7 +120,7 @@ class Hypothesis:
|
||||
|
||||
|
||||
class HypothesisList(object):
|
||||
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
|
||||
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
|
||||
"""
|
||||
Args:
|
||||
data:
|
||||
@ -130,11 +132,10 @@ class HypothesisList(object):
|
||||
self._data = data
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
def data(self) -> Dict[str, Hypothesis]:
|
||||
return self._data
|
||||
|
||||
# def add(self, ys: List[int], log_prob: float):
|
||||
def add(self, hyp: Hypothesis):
|
||||
def add(self, hyp: Hypothesis) -> None:
|
||||
"""Add a Hypothesis to `self`.
|
||||
|
||||
If `hyp` already exists in `self`, its probability is updated using
|
||||
@ -146,8 +147,10 @@ class HypothesisList(object):
|
||||
"""
|
||||
key = hyp.key
|
||||
if key in self:
|
||||
old_hyp = self._data[key]
|
||||
old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
|
||||
old_hyp = self._data[key] # shallow copy
|
||||
torch.logaddexp(
|
||||
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
|
||||
)
|
||||
else:
|
||||
self._data[key] = hyp
|
||||
|
||||
@ -159,7 +162,8 @@ class HypothesisList(object):
|
||||
length_norm:
|
||||
If True, the `log_prob` of a hypothesis is normalized by the
|
||||
number of tokens in it.
|
||||
|
||||
Returns:
|
||||
Return the hypothesis that has the largest `log_prob`.
|
||||
"""
|
||||
if length_norm:
|
||||
return max(
|
||||
@ -171,6 +175,9 @@ class HypothesisList(object):
|
||||
def remove(self, hyp: Hypothesis) -> None:
|
||||
"""Remove a given hypothesis.
|
||||
|
||||
Caution:
|
||||
`self` is modified **in-place**.
|
||||
|
||||
Args:
|
||||
hyp:
|
||||
The hypothesis to be removed from `self`.
|
||||
@ -181,7 +188,7 @@ class HypothesisList(object):
|
||||
assert key in self, f"{key} does not exist"
|
||||
del self._data[key]
|
||||
|
||||
def filter(self, threshold: float) -> "HypothesisList":
|
||||
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
|
||||
"""Remove all Hypotheses whose log_prob is less than threshold.
|
||||
|
||||
Caution:
|
||||
@ -189,10 +196,10 @@ class HypothesisList(object):
|
||||
|
||||
Returns:
|
||||
Return a new HypothesisList containing all hypotheses from `self`
|
||||
that have `log_prob` being greater than the given `threshold`.
|
||||
with `log_prob` being greater than the given `threshold`.
|
||||
"""
|
||||
ans = HypothesisList()
|
||||
for key, hyp in self._data.items():
|
||||
for _, hyp in self._data.items():
|
||||
if hyp.log_prob > threshold:
|
||||
ans.add(hyp) # shallow copy
|
||||
return ans
|
||||
@ -222,6 +229,201 @@ class HypothesisList(object):
|
||||
return ", ".join(s)
|
||||
|
||||
|
||||
def run_decoder(
|
||||
ys: List[int],
|
||||
model: Transducer,
|
||||
decoder_cache: Dict[str, torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""Run the neural decoder model for a given hypothesis.
|
||||
|
||||
Args:
|
||||
ys:
|
||||
The current hypothesis.
|
||||
model:
|
||||
The transducer model.
|
||||
decoder_cache:
|
||||
Cache to save computations.
|
||||
Returns:
|
||||
Return a 1-D tensor of shape (decoder_out_dim,) containing
|
||||
output of `model.decoder`.
|
||||
"""
|
||||
context_size = model.decoder.context_size
|
||||
key = "_".join(map(str, ys[-context_size:]))
|
||||
if key in decoder_cache:
|
||||
return decoder_cache[key]
|
||||
|
||||
device = model.device
|
||||
|
||||
decoder_input = torch.tensor([ys[-context_size:]], device=device).reshape(
|
||||
1, context_size
|
||||
)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
decoder_cache[key] = decoder_out
|
||||
|
||||
return decoder_out
|
||||
|
||||
|
||||
def run_joiner(
|
||||
key: str,
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
encoder_out_len: torch.Tensor,
|
||||
decoder_out_len: torch.Tensor,
|
||||
joint_cache: Dict[str, torch.Tensor],
|
||||
):
|
||||
"""Run the joint network given outputs from the encoder and decoder.
|
||||
|
||||
Args:
|
||||
key:
|
||||
A key into the `joint_cache`.
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
A tensor of shape (1, 1, encoder_out_dim).
|
||||
decoder_out:
|
||||
A tensor of shape (1, 1, decoder_out_dim).
|
||||
encoder_out_len:
|
||||
A tensor with value [1].
|
||||
decoder_out_len:
|
||||
A tensor with value [1].
|
||||
joint_cache:
|
||||
A dict to save computations.
|
||||
Returns:
|
||||
Return a tensor from the output of log-softmax.
|
||||
Its shape is (vocab_size,).
|
||||
"""
|
||||
if key in joint_cache:
|
||||
return joint_cache[key]
|
||||
|
||||
logits = model.joiner(
|
||||
encoder_out,
|
||||
decoder_out,
|
||||
encoder_out_len,
|
||||
decoder_out_len,
|
||||
)
|
||||
|
||||
# TODO(fangjun): Scale the blank posterior
|
||||
log_prob = logits.log_softmax(dim=-1)
|
||||
# log_prob is (1, 1, 1, vocab_size)
|
||||
|
||||
log_prob = log_prob.squeeze()
|
||||
# Now log_prob is (vocab_size,)
|
||||
|
||||
joint_cache[key] = log_prob
|
||||
|
||||
return log_prob
|
||||
|
||||
|
||||
def modified_beam_search(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
beam: int = 4,
|
||||
) -> List[int]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
Args:
|
||||
model:
|
||||
An instance of `Transducer`.
|
||||
encoder_out:
|
||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||
beam:
|
||||
Beam size.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
"""
|
||||
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
# support only batch_size == 1 for now
|
||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||
blank_id = model.decoder.blank_id
|
||||
unk_id = model.decoder.unk_id
|
||||
context_size = model.decoder.context_size
|
||||
|
||||
device = model.device
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[blank_id] * context_size, device=device
|
||||
).reshape(1, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
|
||||
T = encoder_out.size(1)
|
||||
|
||||
B = HypothesisList()
|
||||
B.add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
)
|
||||
)
|
||||
|
||||
encoder_out_len = torch.tensor([1])
|
||||
decoder_out_len = torch.tensor([1])
|
||||
|
||||
for t in range(T):
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
||||
# current_encoder_out is of shape (1, 1, encoder_out_dim)
|
||||
# fmt: on
|
||||
A = list(B)
|
||||
B = HypothesisList()
|
||||
|
||||
ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
|
||||
# ys_log_probs is of shape (num_hyps, 1)
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[hyp.ys[-context_size:] for hyp in A],
|
||||
device=device,
|
||||
)
|
||||
# decoder_input is of shape (num_hyps, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
# decoder_output is of shape (num_hyps, 1, decoder_output_dim)
|
||||
|
||||
current_encoder_out = current_encoder_out.expand(
|
||||
decoder_out.size(0), 1, -1
|
||||
)
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
encoder_out_len.expand(decoder_out.size(0)),
|
||||
decoder_out_len.expand(decoder_out.size(0)),
|
||||
)
|
||||
# logits is of shape (num_hyps, vocab_size)
|
||||
log_probs = logits.log_softmax(dim=-1)
|
||||
|
||||
log_probs.add_(ys_log_probs)
|
||||
|
||||
log_probs = log_probs.reshape(-1)
|
||||
topk_log_probs, topk_indexes = log_probs.topk(beam)
|
||||
|
||||
# topk_hyp_indexes are indexes into `A`
|
||||
topk_hyp_indexes = topk_indexes // logits.size(-1)
|
||||
topk_token_indexes = topk_indexes % logits.size(-1)
|
||||
|
||||
topk_hyp_indexes = topk_hyp_indexes.tolist()
|
||||
topk_token_indexes = topk_token_indexes.tolist()
|
||||
|
||||
for i in range(len(topk_hyp_indexes)):
|
||||
hyp = A[topk_hyp_indexes[i]]
|
||||
new_ys = hyp.ys[:]
|
||||
new_token = topk_token_indexes[i]
|
||||
if new_token != blank_id and new_token != unk_id:
|
||||
new_ys.append(new_token)
|
||||
new_log_prob = topk_log_probs[i]
|
||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||
B.add(new_hyp)
|
||||
|
||||
best_hyp = B.get_most_probable(length_norm=True)
|
||||
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||
|
||||
return ys
|
||||
|
||||
|
||||
def beam_search(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
@ -247,6 +449,7 @@ def beam_search(
|
||||
# support only batch_size == 1 for now
|
||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||
blank_id = model.decoder.blank_id
|
||||
unk_id = model.decoder.unk_id
|
||||
context_size = model.decoder.context_size
|
||||
|
||||
device = model.device
|
||||
@ -261,7 +464,12 @@ def beam_search(
|
||||
t = 0
|
||||
|
||||
B = HypothesisList()
|
||||
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
|
||||
B.add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
)
|
||||
)
|
||||
|
||||
max_sym_per_utt = 20000
|
||||
|
||||
@ -281,58 +489,43 @@ def beam_search(
|
||||
|
||||
joint_cache: Dict[str, torch.Tensor] = {}
|
||||
|
||||
# TODO(fangjun): Implement prefix search to update the `log_prob`
|
||||
# of hypotheses in A
|
||||
|
||||
while True:
|
||||
y_star = A.get_most_probable()
|
||||
A.remove(y_star)
|
||||
|
||||
cached_key = y_star.key
|
||||
decoder_out = run_decoder(
|
||||
ys=y_star.ys, model=model, decoder_cache=decoder_cache
|
||||
)
|
||||
|
||||
if cached_key not in decoder_cache:
|
||||
decoder_input = torch.tensor(
|
||||
[y_star.ys[-context_size:]], device=device
|
||||
).reshape(1, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
decoder_cache[cached_key] = decoder_out
|
||||
else:
|
||||
decoder_out = decoder_cache[cached_key]
|
||||
|
||||
cached_key += f"-t-{t}"
|
||||
if cached_key not in joint_cache:
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
encoder_out_len,
|
||||
decoder_out_len,
|
||||
)
|
||||
|
||||
# TODO(fangjun): Ccale the blank posterior
|
||||
|
||||
log_prob = logits.log_softmax(dim=-1)
|
||||
# log_prob is (1, 1, 1, vocab_size)
|
||||
log_prob = log_prob.squeeze()
|
||||
# Now log_prob is (vocab_size,)
|
||||
joint_cache[cached_key] = log_prob
|
||||
else:
|
||||
log_prob = joint_cache[cached_key]
|
||||
key = "_".join(map(str, y_star.ys[-context_size:]))
|
||||
key += f"-t-{t}"
|
||||
log_prob = run_joiner(
|
||||
key=key,
|
||||
model=model,
|
||||
encoder_out=current_encoder_out,
|
||||
decoder_out=decoder_out,
|
||||
encoder_out_len=encoder_out_len,
|
||||
decoder_out_len=decoder_out_len,
|
||||
joint_cache=joint_cache,
|
||||
)
|
||||
|
||||
# First, process the blank symbol
|
||||
skip_log_prob = log_prob[blank_id]
|
||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
|
||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob
|
||||
|
||||
# ys[:] returns a copy of ys
|
||||
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
|
||||
|
||||
# Second, process other non-blank labels
|
||||
values, indices = log_prob.topk(beam + 1)
|
||||
for i, v in zip(indices.tolist(), values.tolist()):
|
||||
if i == blank_id:
|
||||
for idx in range(values.size(0)):
|
||||
i = indices[idx].item()
|
||||
if i == blank_id or i == unk_id:
|
||||
continue
|
||||
|
||||
new_ys = y_star.ys + [i]
|
||||
new_log_prob = y_star.log_prob + v
|
||||
|
||||
new_log_prob = y_star.log_prob + values[idx]
|
||||
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
|
||||
|
||||
# Check whether B contains more than "beam" elements more probable
|
||||
|
||||
@ -615,7 +615,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
E is the embedding dimension.
|
||||
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
||||
L is the target sequence length, S is the source sequence length.
|
||||
""" # noqa
|
||||
"""
|
||||
|
||||
tgt_len, bsz, embed_dim = query.size()
|
||||
assert embed_dim == embed_dim_to_check
|
||||
@ -635,7 +635,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
elif torch.equal(key, value):
|
||||
# encoder-decoder attention
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias # noqa
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = 0
|
||||
_end = embed_dim
|
||||
@ -643,7 +643,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
q = nn.functional.linear(query, _w, _b)
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias # noqa
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim
|
||||
_end = None
|
||||
@ -653,7 +653,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
|
||||
|
||||
else:
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias # noqa
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = 0
|
||||
_end = embed_dim
|
||||
@ -662,7 +662,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
_b = _b[_start:_end]
|
||||
q = nn.functional.linear(query, _w, _b)
|
||||
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias # noqa
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim
|
||||
_end = embed_dim * 2
|
||||
@ -671,7 +671,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
_b = _b[_start:_end]
|
||||
k = nn.functional.linear(key, _w, _b)
|
||||
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias # noqa
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim * 2
|
||||
_end = None
|
||||
@ -687,12 +687,12 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
or attn_mask.dtype == torch.float16
|
||||
or attn_mask.dtype == torch.uint8
|
||||
or attn_mask.dtype == torch.bool
|
||||
), "Only float, byte, and bool types are supported for attn_mask, not {}".format( # noqa
|
||||
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
|
||||
attn_mask.dtype
|
||||
)
|
||||
if attn_mask.dtype == torch.uint8:
|
||||
warnings.warn(
|
||||
"Byte tensor for attn_mask is deprecated. Use bool tensor instead." # noqa
|
||||
"Byte tensor for attn_mask is deprecated. Use bool tensor instead."
|
||||
)
|
||||
attn_mask = attn_mask.to(torch.bool)
|
||||
|
||||
@ -725,7 +725,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
and key_padding_mask.dtype == torch.uint8
|
||||
):
|
||||
warnings.warn(
|
||||
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." # noqa
|
||||
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
|
||||
)
|
||||
key_padding_mask = key_padding_mask.to(torch.bool)
|
||||
|
||||
@ -760,7 +760,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
# compute attention score
|
||||
# first compute matrix a and matrix c
|
||||
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 # noqa
|
||||
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
|
||||
matrix_ac = torch.matmul(
|
||||
q_with_bias_u, k
|
||||
@ -832,7 +832,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
class ConvolutionModule(nn.Module):
|
||||
"""ConvolutionModule in Conformer model.
|
||||
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py # noqa
|
||||
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
|
||||
|
||||
Args:
|
||||
channels (int): The number of channels of conv layers.
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang
|
||||
# Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -19,16 +20,16 @@
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./transducer_stateless/decode.py \
|
||||
--epoch 14 \
|
||||
--avg 7 \
|
||||
--epoch 29 \
|
||||
--avg 15 \
|
||||
--exp-dir ./transducer_stateless/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search
|
||||
./transducer_stateless/decode.py \
|
||||
--epoch 14 \
|
||||
--avg 7 \
|
||||
--epoch 29 \
|
||||
--avg 15 \
|
||||
--exp-dir ./transducer_stateless/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
@ -45,8 +46,8 @@ from typing import Dict, List, Tuple
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import beam_search, greedy_search
|
||||
from asr_datamodule import TedLiumAsrDataModule
|
||||
from beam_search import beam_search, greedy_search, modified_beam_search
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
@ -77,7 +78,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=13,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
@ -169,6 +170,7 @@ def get_decoder_model(params: AttributeDict):
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.encoder_out_dim,
|
||||
blank_id=params.blank_id,
|
||||
unk_id=params.unk_id,
|
||||
context_size=params.context_size,
|
||||
)
|
||||
return decoder
|
||||
@ -256,6 +258,10 @@ def decode_one_batch(
|
||||
hyp = beam_search(
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp = modified_beam_search(
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
@ -382,14 +388,18 @@ def save_results(
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
TedLiumAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in ("greedy_search", "beam_search")
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
@ -413,6 +423,7 @@ def main():
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
@ -439,16 +450,12 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
tedlium = TedLiumAsrDataModule(args)
|
||||
test_cuts = tedlium.test_cuts()
|
||||
test_dl = tedlium.test_dataloaders(test_cuts)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
test_sets = ["test"]
|
||||
test_dl = [test_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
||||
# Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -37,6 +38,7 @@ class Decoder(nn.Module):
|
||||
vocab_size: int,
|
||||
embedding_dim: int,
|
||||
blank_id: int,
|
||||
unk_id: int,
|
||||
context_size: int,
|
||||
):
|
||||
"""
|
||||
@ -47,6 +49,8 @@ class Decoder(nn.Module):
|
||||
Dimension of the input embedding.
|
||||
blank_id:
|
||||
The ID of the blank symbol.
|
||||
unk_id:
|
||||
The ID of the unk symbol.
|
||||
context_size:
|
||||
Number of previous words to use to predict the next word.
|
||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||
@ -58,6 +62,7 @@ class Decoder(nn.Module):
|
||||
padding_idx=blank_id,
|
||||
)
|
||||
self.blank_id = blank_id
|
||||
self.unk_id = unk_id
|
||||
|
||||
assert context_size >= 1, context_size
|
||||
self.context_size = context_size
|
||||
|
||||
@ -120,7 +120,6 @@ class Transducer(nn.Module):
|
||||
target_lengths=y_lens,
|
||||
blank=blank_id,
|
||||
reduction="sum",
|
||||
from_log_softmax=False,
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
@ -50,7 +50,7 @@ import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import beam_search, greedy_search
|
||||
from beam_search import beam_search, greedy_search, modified_beam_search
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
@ -167,6 +167,7 @@ def get_decoder_model(params: AttributeDict):
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.encoder_out_dim,
|
||||
blank_id=params.blank_id,
|
||||
unk_id=params.unk_id,
|
||||
context_size=params.context_size,
|
||||
)
|
||||
return decoder
|
||||
@ -230,6 +231,7 @@ def main():
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(f"{params}")
|
||||
@ -300,6 +302,10 @@ def main():
|
||||
hyp = beam_search(
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||
)
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp = modified_beam_search(
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {params.method}")
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
||||
# Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -18,7 +19,7 @@
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
cd icefall/egs/tedlium3/ASR
|
||||
python ./transducer_stateless/test_decoder.py
|
||||
"""
|
||||
|
||||
@ -29,6 +30,7 @@ from decoder import Decoder
|
||||
def test_decoder():
|
||||
vocab_size = 3
|
||||
blank_id = 0
|
||||
unk_id = 2
|
||||
embedding_dim = 128
|
||||
context_size = 4
|
||||
|
||||
@ -36,6 +38,7 @@ def test_decoder():
|
||||
vocab_size=vocab_size,
|
||||
embedding_dim=embedding_dim,
|
||||
blank_id=blank_id,
|
||||
unk_id=unk_id,
|
||||
context_size=context_size,
|
||||
)
|
||||
N = 100
|
||||
|
||||
@ -26,9 +26,8 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir transducer_stateless/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 250 \
|
||||
--lr-factor 2.5
|
||||
--max-duration 180 \
|
||||
--lr-factor 5.0
|
||||
"""
|
||||
|
||||
|
||||
@ -56,6 +55,8 @@ from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from transformer import Noam
|
||||
|
||||
from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids
|
||||
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
@ -233,6 +234,7 @@ def get_decoder_model(params: AttributeDict):
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.encoder_out_dim,
|
||||
blank_id=params.blank_id,
|
||||
unk_id=params.unk_id,
|
||||
context_size=params.context_size,
|
||||
)
|
||||
return decoder
|
||||
@ -379,7 +381,9 @@ def compute_loss(
|
||||
feature_lens = supervisions["num_frames"].to(device)[: feature.size(0)]
|
||||
|
||||
texts = batch["supervisions"]["text"][: feature.size(0)]
|
||||
y = sp.encode(texts, out_type=int)
|
||||
|
||||
unk_id = params.unk_id
|
||||
y = convert_texts_into_ids(texts, unk_id, sp=sp)
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
@ -565,6 +569,7 @@ def run(rank, world_size, args):
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user