Sorted imports for transducer_emformer/streaming_feature_extractor.py

This commit is contained in:
yaozengwei 2022-04-22 11:04:50 +08:00
parent 8fde2acd97
commit e97c9fbdbf

View File

@ -14,10 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import torch
from beam_search import HypothesisList
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
from typing import List, Optional
import torch
def _create_streaming_feature_extractor() -> OnlineFeature:
@ -41,6 +42,15 @@ def _create_streaming_feature_extractor() -> OnlineFeature:
class FeatureExtractionStream(object):
def __init__(self, context_size: int, decoding_method: str) -> None:
"""
Args:
context_size:
Context size of the RNN-T decoder model.
decoding_method:
Decoding method. The possible values are:
- greedy_search
- modified_beam_search
"""
self.feature_extractor = _create_streaming_feature_extractor()
# It contains a list of 1-D tensors representing the feature frames.
self.feature_frames: List[torch.Tensor] = []