From 778a9c5903858b5d60f9ea17b151bfc0590df4a0 Mon Sep 17 00:00:00 2001 From: makaveli10 Date: Thu, 14 Nov 2024 10:07:37 -0500 Subject: [PATCH 1/4] Upgrade faster-whisper 1.1.0rc0 Signed-off-by: makaveli10 --- requirements/server.txt | 6 +- whisper_live/server.py | 7 +- whisper_live/transcriber.py | 1494 ++++++++++++++++++++++++++++------- 3 files changed, 1209 insertions(+), 298 deletions(-) diff --git a/requirements/server.txt b/requirements/server.txt index 402f967..8e10592 100644 --- a/requirements/server.txt +++ b/requirements/server.txt @@ -1,4 +1,4 @@ -faster-whisper==1.0.1 +faster-whisper @ https://github.com/SYSTRAN/faster-whisper/archive/8f01aee36b562e6be537e0341cdd40dc8bed33a7.tar.gz websockets onnxruntime==1.16.0 numba @@ -9,5 +9,5 @@ scipy jiwer evaluate numpy<2 -tiktoken==0.3.3 -openai-whisper==20231117 \ No newline at end of file +tiktoken==0.8.0 +openai-whisper==20240930 \ No newline at end of file diff --git a/whisper_live/server.py b/whisper_live/server.py index b68df5a..e8f241c 100644 --- a/whisper_live/server.py +++ b/whisper_live/server.py @@ -780,8 +780,11 @@ def __init__(self, websocket, task="transcribe", device=None, language=None, cli super().__init__(client_uid, websocket) self.model_sizes = [ "tiny", "tiny.en", "base", "base.en", "small", "small.en", - "medium", "medium.en", "large-v2", "large-v3", + "medium", "medium.en", "large-v2", "large-v3", "distil-small.en", + "distil-medium.en", "distil-large-v2", "distil-large-v3", + "large-v3-turbo", "turbo" ] + if not os.path.exists(model): self.model_size_or_path = self.check_valid_model(model) else: @@ -789,7 +792,7 @@ def __init__(self, websocket, task="transcribe", device=None, language=None, cli self.language = "en" if self.model_size_or_path.endswith("en") else language self.task = task self.initial_prompt = initial_prompt - self.vad_parameters = vad_parameters or {"threshold": 0.5} + self.vad_parameters = vad_parameters or {"onset": 0.5} self.no_speech_thresh = 0.45 device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/whisper_live/transcriber.py b/whisper_live/transcriber.py index 7424570..a648063 100644 --- a/whisper_live/transcriber.py +++ b/whisper_live/transcriber.py @@ -4,14 +4,22 @@ import json import logging import os +import random import zlib +from collections import Counter, defaultdict +from dataclasses import asdict, dataclass from inspect import signature -from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union +from math import ceil +from typing import BinaryIO, Iterable, List, Optional, Tuple, Union +from warnings import warn import ctranslate2 import numpy as np import tokenizers +import torch + +from tqdm import tqdm from faster_whisper.audio import decode_audio, pad_or_trim from faster_whisper.feature_extractor import FeatureExtractor @@ -22,31 +30,52 @@ VadOptions, collect_chunks, get_speech_timestamps, + merge_segments, ) -class Word(NamedTuple): +@dataclass +class Word: start: float end: float word: str probability: float + def _asdict(self): + warn( + "Word._asdict() method is deprecated, use dataclasses.asdict(Word) instead", + DeprecationWarning, + 2, + ) + return asdict(self) + -class Segment(NamedTuple): +@dataclass +class Segment: id: int seek: int start: float end: float text: str tokens: List[int] - temperature: float avg_logprob: float compression_ratio: float no_speech_prob: float words: Optional[List[Word]] + temperature: Optional[float] = 1.0 + + def _asdict(self): + warn( + "Segment._asdict() method is deprecated, use 
dataclasses.asdict(Segment) instead", + DeprecationWarning, + 2, + ) + return asdict(self) -class TranscriptionOptions(NamedTuple): +# Added additional parameters for multilingual videos and fixes below +@dataclass +class TranscriptionOptions: beam_size: int best_of: int patience: float @@ -54,6 +83,7 @@ class TranscriptionOptions(NamedTuple): repetition_penalty: float no_repeat_ngram_size: int log_prob_threshold: Optional[float] + log_prob_low_threshold: Optional[float] no_speech_threshold: Optional[float] compression_ratio_threshold: Optional[float] condition_on_previous_text: bool @@ -68,12 +98,16 @@ class TranscriptionOptions(NamedTuple): word_timestamps: bool prepend_punctuations: str append_punctuations: str + multilingual: bool + output_language: Optional[str] max_new_tokens: Optional[int] clip_timestamps: Union[str, List[float]] hallucination_silence_threshold: Optional[float] + hotwords: Optional[str] -class TranscriptionInfo(NamedTuple): +@dataclass +class TranscriptionInfo: language: str language_probability: float duration: float @@ -83,6 +117,418 @@ class TranscriptionInfo(NamedTuple): vad_options: VadOptions +# The code below is originally from HF pipeline and is used in whisper-x +# (https://github.com/m-bain/whisperX) and adapted for faster_whisper + + +class BatchedInferencePipeline: + """ + Huggingface Pipeline wrapper for WhisperModel. + Copyright (c) 2022, Max Bain + All rights reserved. + Modified by Mobius Labs GmbH + """ + + def __init__( + self, + model, + options: Optional[TranscriptionOptions] = None, + tokenizer=None, + language: Optional[str] = None, + ): + self.model: WhisperModel = model + self.tokenizer = tokenizer + self.options = options + self.preset_language = language + self.last_speech_timestamp = 0.0 + + def forward(self, features, chunks_metadata, **forward_params): + encoder_output, outputs = self.model.generate_segment_batched( + features, self.tokenizer, forward_params + ) + + segmented_outputs = [] + segment_sizes = [] + for chunk_metadata, output in zip(chunks_metadata, outputs): + duration = chunk_metadata["end_time"] - chunk_metadata["start_time"] + segment_size = int(ceil(duration) * self.model.frames_per_second) + segment_sizes.append(segment_size) + ( + subsegments, + seek, + single_timestamp_ending, + ) = self.model._split_segments_by_timestamps( + tokenizer=self.tokenizer, + tokens=output["tokens"], + time_offset=chunk_metadata["start_time"], + segment_size=segment_size, + segment_duration=duration, + seek=0, + ) + segmented_outputs.append( + [ + dict( + text=self.tokenizer.decode(subsegment["tokens"]), + avg_logprob=output["avg_logprob"], + no_speech_prob=output["no_speech_prob"], + tokens=subsegment["tokens"], + start=subsegment["start"], + end=subsegment["end"], + compression_ratio=get_compression_ratio( + self.tokenizer.decode(subsegment["tokens"]) + ), + ) + for subsegment in subsegments + ] + ) + if forward_params["word_timestamps"]: + self.last_speech_timestamp = self.model.add_word_timestamps( + segmented_outputs, + self.tokenizer, + encoder_output, + segment_sizes, + forward_params["prepend_punctuations"], + forward_params["append_punctuations"], + self.last_speech_timestamp, + ) + + return segmented_outputs + + def get_language_and_tokenizer( + self, audio, task: Optional[str] = None, language: Optional[str] = None + ): + all_language_probs = None + language_probability = 1.0 + + if self.tokenizer is None: + if not language: + ( + language, + language_probability, + all_language_probs, + ) = self.model.detect_language(audio) + 
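
Word and Segment are now dataclasses rather than NamedTuples (see the top of this hunk), so callers that relied on the NamedTuple helpers need a small migration: `_asdict()` only survives as a deprecated shim and `_replace()` is gone. A minimal sketch of the new-style calls, with illustrative values and the vendored import path of this repository:

    from dataclasses import asdict, replace

    from whisper_live.transcriber import Word  # vendored module path in this repo

    w = Word(start=0.0, end=0.42, word=" hello", probability=0.91)

    # NamedTuple-era calls that need migrating:
    #   w._asdict()   still works but emits a DeprecationWarning
    #   w._replace()  no longer exists on a dataclass
    payload = asdict(w)              # plain dict of the fields
    shifted = replace(w, start=0.1)  # copy with one field changed
    w.end = 0.5                      # dataclasses are mutable, so in-place edits also work
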
task = task or "transcribe" + self.tokenizer = Tokenizer( + self.model.hf_tokenizer, + self.model.model.is_multilingual, + task=task, + language=language, + ) + else: + if task is not None: + self.tokenizer.task = self.tokenizer.tokenizer.token_to_id( + f"<|{task}|>" + ) + + if language is not None: + self.tokenizer.language = self.tokenizer.tokenizer.token_to_id( + f"<|{language}|>" + ) + self.tokenizer.language_code = language + + return language, language_probability, task, all_language_probs + + def transcribe( + self, + audio: Union[str, BinaryIO, torch.Tensor, np.ndarray], + language: Optional[str] = None, + task: str = None, + log_progress: bool = False, + beam_size: int = 5, + best_of: int = 5, + patience: float = 1, + length_penalty: float = 1, + repetition_penalty: float = 1, + no_repeat_ngram_size: int = 0, + temperature: Union[float, List[float], Tuple[float, ...]] = [ + 0.0, + 0.2, + 0.4, + 0.6, + 0.8, + 1.0, + ], + compression_ratio_threshold: Optional[float] = 2.4, + log_prob_threshold: Optional[float] = -1.0, + log_prob_low_threshold: Optional[float] = None, + no_speech_threshold: Optional[float] = 0.6, + initial_prompt: Optional[Union[str, Iterable[int]]] = None, + prefix: Optional[str] = None, + suppress_blank: bool = True, + suppress_tokens: Optional[List[int]] = [-1], + without_timestamps: bool = True, + word_timestamps: bool = False, + prepend_punctuations: str = "\"'“¿([{-", + append_punctuations: str = "\"'.。,,!!??::”)]}、", + vad_filter: bool = True, + vad_parameters: Optional[Union[dict, VadOptions]] = None, + max_new_tokens: Optional[int] = None, + chunk_length: Optional[int] = None, + clip_timestamps: Optional[List[dict]] = None, + batch_size: int = 16, + hotwords: Optional[str] = None, + ) -> Tuple[Iterable[Segment], TranscriptionInfo]: + """transcribe audio in chunks in batched fashion and return with language info. + + Arguments: + audio: Path to the input file (or a file-like object), or the audio waveform. + language: The language spoken in the audio. It should be a language code such + as "en" or "fr". If not set, the language will be detected in the first 30 seconds + of audio. + task: Task to execute (transcribe or translate). + log_progress: whether to show progress bar or not. + beam_size: Beam size to use for decoding. + best_of: Number of candidates when sampling with non-zero temperature. + patience: Beam search patience factor. + length_penalty: Exponential length penalty constant. + repetition_penalty: Penalty applied to the score of previously generated tokens + (set > 1 to penalize). + no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable). + temperature: Temperature for sampling. It can be a tuple of temperatures, + which will be successively used upon failures according to either + `compression_ratio_threshold` or `log_prob_threshold`. + compression_ratio_threshold: If the gzip compression ratio is above this value, + treat as failed. + log_prob_threshold: If the average log probability over sampled tokens is + below this value, treat as failed. + log_prob_low_threshold: This parameter alone is sufficient to skip an output text, + whereas log_prob_threshold also looks for appropriate no_speech_threshold value. + This value should be less than log_prob_threshold. + no_speech_threshold: If the no_speech probability is higher than this value AND + the average log probability over sampled tokens is below `log_prob_threshold`, + consider the segment as silent. 
+ initial_prompt: Optional text string or iterable of token ids to provide as a + prompt for the first window. + prefix: Optional text to provide as a prefix for the first window. + suppress_blank: Suppress blank outputs at the beginning of the sampling. + suppress_tokens: List of token IDs to suppress. -1 will suppress a default set + of symbols as defined in `tokenizer.non_speech_tokens()`. + without_timestamps: Only sample text tokens. + word_timestamps: Extract word-level timestamps using the cross-attention pattern + and dynamic time warping, and include the timestamps for each word in each segment. + Set as False. + prepend_punctuations: If word_timestamps is True, merge these punctuation symbols + with the next word + append_punctuations: If word_timestamps is True, merge these punctuation symbols + with the previous word + vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio + without speech. This step is using the Silero VAD model + https://github.com/snakers4/silero-vad. + vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available + parameters and default values in the class `VadOptions`). + max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set, + the maximum will be set by the default max_length. + chunk_length: The length of audio segments. If it is not None, it will overwrite the + default chunk_length of the FeatureExtractor. + clip_timestamps: Optionally provide list of dictionaries each containing "start" and + "end" keys that specify the start and end of the voiced region within + `chunk_length` boundary. vad_filter will be ignored if clip_timestamps is used. + batch_size: the maximum number of parallel requests to model for decoding. + hotwords: + Hotwords/hint phrases to the model. Has no effect if prefix is not None. + + Static params: (Fixed for batched version) + max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0. + multilingual: If True, perform transcription on multilingual videos. Set as False. + output_language: Valid only if multilingual is set to True. + Specifies the string representing the output language. One of + 'en' (English) or 'hybrid' (code-switched transcription). set as None. + condition_on_previous_text: If True, the previous output of the model is provided + as a prompt for the next window; disabling may make the text inconsistent across + windows, but the model becomes less prone to getting stuck in a failure loop, + such as repetition looping or timestamps going out of sync. Set as False + prompt_reset_on_temperature: Resets prompt if temperature is above this value. + Arg has effect only if condition_on_previous_text is True. Set at 0.5 + #TODO: support "hallucination_silence_threshold" when "word_timestamps=True" + hallucination_silence_threshold: Optional[float] + When word_timestamps is True, skip silent periods longer than this threshold + (in seconds) when a possible hallucination is detected. set as None. + + unused: + language_detection_threshold: If the maximum probability of the language tokens is + higher than this value, the language is detected. + language_detection_segments: Number of segments to consider for the language detection. 
+ + + Returns: + A tuple with: + + - a generator over transcribed segments + - an instance of TranscriptionInfo + """ + + sampling_rate = self.model.feature_extractor.sampling_rate + + if isinstance(audio, np.ndarray): + audio = torch.from_numpy(audio) + elif not isinstance(audio, torch.Tensor): + audio = decode_audio(audio, sampling_rate=sampling_rate) + duration = audio.shape[0] / sampling_rate + + chunk_length = chunk_length or self.model.feature_extractor.chunk_length + # if no segment split is provided, use vad_model and generate segments + if not clip_timestamps: + if vad_filter: + if vad_parameters is None: + vad_parameters = VadOptions( + max_speech_duration_s=chunk_length, + min_silence_duration_ms=160, + ) + elif isinstance(vad_parameters, dict): + if "max_speech_duration_s" in vad_parameters.keys(): + vad_parameters.pop("max_speech_duration_s") + + vad_parameters = VadOptions( + **vad_parameters, max_speech_duration_s=chunk_length + ) + + active_segments = get_speech_timestamps(audio, vad_parameters) + clip_timestamps = merge_segments(active_segments, vad_parameters) + # run the audio if it is less than 30 sec even without clip_timestamps + elif duration < chunk_length: + clip_timestamps = [{"start": 0, "end": audio.shape[0]}] + else: + raise RuntimeError( + "No clip timestamps found. " + "Set 'vad_filter' to True or provide 'clip_timestamps'." + ) + if self.model.model.is_multilingual: + language = language or self.preset_language + elif language != "en": + if language is not None: + self.model.logger.warning( + f"English-only model is used, but {language} language is" + " chosen, setting language to 'en'." + ) + language = "en" + + ( + language, + language_probability, + task, + all_language_probs, + ) = self.get_language_and_tokenizer(audio, task, language) + + duration_after_vad = ( + sum((segment["end"] - segment["start"]) for segment in clip_timestamps) + / sampling_rate + ) + + # batched options: see the difference with default options in WhisperModel + batched_options = TranscriptionOptions( + beam_size=beam_size, + best_of=best_of, + patience=patience, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + log_prob_threshold=log_prob_threshold, + log_prob_low_threshold=log_prob_low_threshold, + no_speech_threshold=no_speech_threshold, + compression_ratio_threshold=compression_ratio_threshold, + temperatures=( + temperature if isinstance(temperature, (list, tuple)) else [temperature] + ), + initial_prompt=initial_prompt, + prefix=prefix, + suppress_blank=suppress_blank, + suppress_tokens=get_suppressed_tokens(self.tokenizer, suppress_tokens), + prepend_punctuations=prepend_punctuations, + append_punctuations=append_punctuations, + max_new_tokens=max_new_tokens, + hotwords=hotwords, + word_timestamps=word_timestamps, + hallucination_silence_threshold=None, + condition_on_previous_text=False, + clip_timestamps="0", + prompt_reset_on_temperature=0.5, + multilingual=False, + output_language=None, + without_timestamps=without_timestamps, + max_initial_timestamp=0.0, + ) + + info = TranscriptionInfo( + language=language, + language_probability=language_probability, + duration=duration, + duration_after_vad=duration_after_vad, + transcription_options=batched_options, + vad_options=None, + all_language_probs=all_language_probs, + ) + + audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps) + to_cpu = ( + self.model.model.device == "cuda" and len(self.model.model.device_index) > 1 + ) + features = ( 
+ torch.stack( + [ + pad_or_trim( + self.model.feature_extractor(chunk, to_cpu=to_cpu)[ + ..., + : chunk.shape[0] // self.model.feature_extractor.hop_length, + ] + ) + for chunk in audio_chunks + ] + ) + if duration_after_vad + else [] + ) + + segments = self._batched_segments_generator( + features, + chunks_metadata, + batch_size, + batched_options, + log_progress, + ) + + return segments, info + + def _batched_segments_generator( + self, features, chunks_metadata, batch_size, options, log_progress + ): + pbar = tqdm(total=len(features), disable=not log_progress, position=0) + seg_idx = 0 + for i in range(0, len(features), batch_size): + results = self.forward( + features[i : i + batch_size], + chunks_metadata[i : i + batch_size], + **asdict(options), + ) + + for result in results: + for segment in result: + seg_idx += 1 + yield Segment( + seek=int(result[-1]["end"] * self.model.frames_per_second), + id=seg_idx, + text=segment["text"], + start=round(segment["start"], 3), + end=round(segment["end"], 3), + words=( + None + if not options.word_timestamps + else [Word(**word) for word in segment["words"]] + ), + tokens=segment["tokens"], + avg_logprob=segment["avg_logprob"], + no_speech_prob=segment["no_speech_prob"], + compression_ratio=segment["compression_ratio"], + ) + + pbar.update(1) + + pbar.close() + # revert the tokenizer if multilingual inference is enabled + if self.preset_language is None: + self.tokenizer = None + self.last_speech_timestamp = 0.0 + + class WhisperModel: def __init__( self, @@ -94,14 +540,17 @@ def __init__( num_workers: int = 1, download_root: Optional[str] = None, local_files_only: bool = False, + files: dict = None, + **model_kwargs, ): """Initializes the Whisper model. Args: model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en, - small, small.en, medium, medium.en, large-v1, large-v2, large-v3, or large), a path to a - converted model directory, or a CTranslate2-converted Whisper model ID from the HF Hub. - When a size or a model ID is configured, the converted model is downloaded + small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1, + large-v2, large-v3, large, distil-large-v2, distil-large-v3, large-v3-turbo, or turbo), + a path to a converted model directory, or a CTranslate2-converted Whisper model ID from + the HF Hub. When a size or a model ID is configured, the converted model is downloaded from the Hugging Face Hub. device: Device to use for computation ("cpu", "cuda", "auto"). device_index: Device ID to use. @@ -120,10 +569,18 @@ def __init__( are saved in the standard Hugging Face cache directory. local_files_only: If True, avoid downloading the file and return the path to the local cached file if it exists. + files: Load model files from the memory. This argument is a dictionary mapping file names + to file contents as file-like or bytes objects. If this is set, model_path acts as an + identifier for this model. 
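
The BatchedInferencePipeline introduced above complements the sequential WhisperModel.transcribe() path. A rough usage sketch against the vendored module, assuming a CUDA host; the model size, file name and batch size are placeholders rather than values mandated by this patch:

    from whisper_live.transcriber import BatchedInferencePipeline, WhisperModel

    model = WhisperModel("large-v3", device="cuda", compute_type="float16")
    batched = BatchedInferencePipeline(model)

    # VAD-segments the input, batches the voiced chunks and yields Segment objects lazily
    segments, info = batched.transcribe("audio.wav", batch_size=16)
    print(info.language, info.language_probability)
    for seg in segments:
        print(f"[{seg.start:7.2f} -> {seg.end:7.2f}] {seg.text}")
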
""" self.logger = get_logger() - if os.path.isdir(model_size_or_path): + tokenizer_bytes, preprocessor_bytes = None, None + if files: + model_path = model_size_or_path + tokenizer_bytes = files.pop("tokenizer.json", None) + preprocessor_bytes = files.pop("preprocessor_config.json", None) + elif os.path.isdir(model_size_or_path): model_path = model_size_or_path else: model_path = download_model( @@ -131,34 +588,43 @@ def __init__( local_files_only=local_files_only, cache_dir=download_root, ) - + self.device = device + # set the random seed to make sure consistency across runs + ctranslate2.set_random_seed(42) self.model = ctranslate2.models.Whisper( model_path, - device=device, + device=self.device, device_index=device_index, compute_type=compute_type, intra_threads=cpu_threads, inter_threads=num_workers, + files=files, + **model_kwargs, ) tokenizer_file = os.path.join(model_path, "tokenizer.json") - if os.path.isfile(tokenizer_file): + if tokenizer_bytes: + self.hf_tokenizer = tokenizers.Tokenizer.from_buffer(tokenizer_bytes) + elif os.path.isfile(tokenizer_file): self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file) else: self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained( "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en") ) - - self.feat_kwargs = self._get_feature_kwargs(model_path) - self.feature_extractor = FeatureExtractor(**self.feat_kwargs) - self.num_samples_per_token = self.feature_extractor.hop_length * 2 + self.feat_kwargs = self._get_feature_kwargs(model_path, preprocessor_bytes) + self.feature_extractor = FeatureExtractor( + **self.feat_kwargs, device=self.device + ) + self.input_stride = 2 + self.num_samples_per_token = ( + self.feature_extractor.hop_length * self.input_stride + ) self.frames_per_second = ( self.feature_extractor.sampling_rate // self.feature_extractor.hop_length ) self.tokens_per_second = ( self.feature_extractor.sampling_rate // self.num_samples_per_token ) - self.input_stride = 2 self.time_precision = 0.02 self.max_length = 448 @@ -167,25 +633,27 @@ def supported_languages(self) -> List[str]: """The languages supported by the model.""" return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"] - def _get_feature_kwargs(self, model_path) -> dict: - preprocessor_config_file = os.path.join(model_path, "preprocessor_config.json") + def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict: config = {} - if os.path.isfile(preprocessor_config_file): - try: - with open(preprocessor_config_file, "r", encoding="utf-8") as json_file: - config = json.load(json_file) - valid_keys = signature(FeatureExtractor.__init__).parameters.keys() - config = {k: v for k, v in config.items() if k in valid_keys} - except json.JSONDecodeError as e: - self.logger.warning( - "Could not load preprocessor_config.json: %s", str(e) - ) + try: + config_path = os.path.join(model_path, "preprocessor_config.json") + if preprocessor_bytes: + config = json.loads(preprocessor_bytes) + elif os.path.isfile(config_path): + with open(config_path, "r", encoding="utf-8") as file: + config = json.load(file) + else: + return config + valid_keys = signature(FeatureExtractor.__init__).parameters.keys() + return {k: v for k, v in config.items() if k in valid_keys} + except json.JSONDecodeError as e: + self.logger.warning("Could not load preprocessor config: %s", e) return config - def transcribe( # noqa: C901 + def transcribe( self, - audio: Union[str, BinaryIO, np.ndarray], + audio: Union[str, BinaryIO, torch.Tensor, np.ndarray], language: 
Optional[str] = None, task: str = "transcribe", beam_size: int = 5, @@ -204,6 +672,7 @@ def transcribe( # noqa: ], compression_ratio_threshold: Optional[float] = 2.4, log_prob_threshold: Optional[float] = -1.0, + log_prob_low_threshold: Optional[float] = None, no_speech_threshold: Optional[float] = 0.6, condition_on_previous_text: bool = True, prompt_reset_on_temperature: float = 0.5, @@ -216,12 +685,17 @@ def transcribe( # noqa: word_timestamps: bool = False, prepend_punctuations: str = "\"'“¿([{-", append_punctuations: str = "\"'.。,,!!??::”)]}、", + multilingual: bool = False, + output_language: Optional[str] = None, vad_filter: bool = False, vad_parameters: Optional[Union[dict, VadOptions]] = None, max_new_tokens: Optional[int] = None, chunk_length: Optional[int] = None, clip_timestamps: Union[str, List[float]] = "0", hallucination_silence_threshold: Optional[float] = None, + hotwords: Optional[str] = None, + language_detection_threshold: Optional[float] = 0.5, + language_detection_segments: int = 1, ) -> Tuple[Iterable[Segment], TranscriptionInfo]: """Transcribes an input file. @@ -245,6 +719,9 @@ def transcribe( # noqa: treat as failed. log_prob_threshold: If the average log probability over sampled tokens is below this value, treat as failed. + log_prob_low_threshold: This parameter alone is sufficient to skip an output text, + wheras log_prob_threshold also looks for appropriate no_speech_threshold value. + This value should be less than log_prob_threshold. no_speech_threshold: If the no_speech probability is higher than this value AND the average log probability over sampled tokens is below `log_prob_threshold`, consider the segment as silent. @@ -259,7 +736,7 @@ def transcribe( # noqa: prefix: Optional text to provide as a prefix for the first window. suppress_blank: Suppress blank outputs at the beginning of the sampling. suppress_tokens: List of token IDs to suppress. -1 will suppress a default set - of symbols as defined in the model config.json file. + of symbols as defined in `tokenizer.non_speech_tokens()`. without_timestamps: Only sample text tokens. max_initial_timestamp: The initial timestamp cannot be later than this. word_timestamps: Extract word-level timestamps using the cross-attention pattern @@ -268,6 +745,12 @@ def transcribe( # noqa: with the next word append_punctuations: If word_timestamps is True, merge these punctuation symbols with the previous word + multilingual: If True, perform transcription on multilingual videos + and return the transcript based + on the 'output_language' flag. + output_language: Valid only if multilingual is set to True. + Specifies the string representing the output language. One of + 'en' (English) or 'hybrid' (code-switched transcription). vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio without speech. This step is using the Silero VAD model https://github.com/snakers4/silero-vad. @@ -277,22 +760,30 @@ def transcribe( # noqa: the maximum will be set by the default max_length. chunk_length: The length of audio segments. If it is not None, it will overwrite the default chunk_length of the FeatureExtractor. - clip_timestamps: Union[str, List[float]] + clip_timestamps: Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process. The last end timestamp defaults to the end of the file. - hallucination_silence_threshold: Optional[float] + vad_filter will be ignored if clip_timestamps is used. 
+ hallucination_silence_threshold: When word_timestamps is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected - + hotwords: + Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None. + language_detection_threshold: If the maximum probability of the language tokens is higher + than this value, the language is detected. + language_detection_segments: Number of segments to consider for the language detection. Returns: A tuple with: - a generator over transcribed segments - an instance of TranscriptionInfo """ + sampling_rate = self.feature_extractor.sampling_rate - if not isinstance(audio, np.ndarray): + if isinstance(audio, np.ndarray): + audio = torch.from_numpy(audio) + elif not isinstance(audio, torch.Tensor): audio = decode_audio(audio, sampling_rate=sampling_rate) duration = audio.shape[0] / sampling_rate @@ -302,13 +793,14 @@ def transcribe( # noqa: "Processing audio with duration %s", format_timestamp(duration) ) - if vad_filter: + if vad_filter and clip_timestamps == "0": if vad_parameters is None: vad_parameters = VadOptions() elif isinstance(vad_parameters, dict): vad_parameters = VadOptions(**vad_parameters) speech_chunks = get_speech_timestamps(audio, vad_parameters) - audio = collect_chunks(audio, speech_chunks) + audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks) + audio = torch.cat(audio_chunks, dim=0) duration_after_vad = audio.shape[0] / sampling_rate self.logger.info( @@ -334,26 +826,81 @@ def transcribe( # noqa: if audio.shape[0] == 0: return None, None - - features = self.feature_extractor(audio, chunk_length=chunk_length) + + to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 + features = self.feature_extractor( + audio, chunk_length=chunk_length, to_cpu=to_cpu + ) encoder_output = None all_language_probs = None + # setting output_language for multilingual videos + if multilingual: + if output_language is None: + output_language = "en" + elif output_language not in ["en", "hybrid"]: + raise ValueError("Output language needs to be one of 'en'/'hybrid'.") + + # detecting the language if not provided if language is None: if not self.model.is_multilingual: language = "en" language_probability = 1 else: - segment = features[:, : self.feature_extractor.nb_max_frames] - encoder_output = self.encode(segment) - # results is a list of tuple[str, float] with language names and - # probabilities. 
- results = self.model.detect_language(encoder_output)[0] - # Parse language names to strip out markers - all_language_probs = [(token[2:-2], prob) for (token, prob) in results] - # Get top language token and probability - language, language_probability = all_language_probs[0] + if ( + language_detection_segments is None + or language_detection_segments < 1 + ): + language_detection_segments = 1 + start_timestamp = ( + float(clip_timestamps.split(",")[0]) + if isinstance(clip_timestamps, str) + else clip_timestamps[0] + ) + content_frames = ( + features.shape[-1] - self.feature_extractor.nb_max_frames + ) + seek = ( + int(start_timestamp * self.frames_per_second) + if start_timestamp * self.frames_per_second < content_frames + else 0 + ) + end_frames = min( + seek + + self.feature_extractor.nb_max_frames + * language_detection_segments, + content_frames, + ) + detected_language_info = {} + while seek <= end_frames: + segment = features[ + :, seek : seek + self.feature_extractor.nb_max_frames + ] + encoder_output = self.encode(pad_or_trim(segment)) + # results is a list of tuple[str, float] with language names and + # probabilities. + results = self.model.detect_language(encoder_output)[0] + # Parse language names to strip out markers + all_language_probs = [ + (token[2:-2], prob) for (token, prob) in results + ] + # Get top language token and probability + language, language_probability = all_language_probs[0] + if language_probability > language_detection_threshold: + break + detected_language_info.setdefault(language, []).append( + language_probability + ) + seek += segment.shape[-1] + else: + # If no language detected for all segments, the majority vote of the highest + # projected languages for all segments is used to determine the language. + language = max( + detected_language_info, + key=lambda lang: len(detected_language_info[lang]), + ) + language_probability = max(detected_language_info[language]) self.logger.info( "Detected language '%s' with probability %.2f", @@ -385,6 +932,7 @@ def transcribe( # noqa: repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, log_prob_threshold=log_prob_threshold, + log_prob_low_threshold=log_prob_low_threshold, no_speech_threshold=no_speech_threshold, compression_ratio_threshold=compression_ratio_threshold, condition_on_previous_text=condition_on_previous_text, @@ -395,15 +943,22 @@ def transcribe( # noqa: initial_prompt=initial_prompt, prefix=prefix, suppress_blank=suppress_blank, - suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens), + suppress_tokens=( + get_suppressed_tokens(tokenizer, suppress_tokens) + if suppress_tokens + else suppress_tokens + ), without_timestamps=without_timestamps, max_initial_timestamp=max_initial_timestamp, word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, + multilingual=multilingual, + output_language=output_language, max_new_tokens=max_new_tokens, clip_timestamps=clip_timestamps, hallucination_silence_threshold=hallucination_silence_threshold, + hotwords=hotwords, ) segments = self.generate_segments(features, tokenizer, options, encoder_output) @@ -420,12 +975,90 @@ def transcribe( # noqa: vad_options=vad_parameters, all_language_probs=all_language_probs, ) - return segments, info + def _split_segments_by_timestamps( + self, + tokenizer: Tokenizer, + tokens: List[int], + time_offset: float, + segment_size: int, + segment_duration: float, + seek: int, + ) -> List[List[int]]: + current_segments = [] + 
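
Language detection in WhisperModel.transcribe() now walks up to language_detection_segments 30-second windows and stops as soon as one window clears language_detection_threshold, falling back to a majority vote otherwise (see the loop above). A sketch of opting into the multi-window behaviour; the file name is a placeholder and `model` is the WhisperModel instance from the earlier sketch:

    # model: the WhisperModel instance from the earlier sketch
    segments, info = model.transcribe(
        "audio.wav",
        language=None,                     # let the model pick the language
        language_detection_segments=3,     # scan up to three 30 s windows
        language_detection_threshold=0.5,  # stop early once a window is this confident
    )
    print(f"detected {info.language} (p={info.language_probability:.2f})")
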
single_timestamp_ending = ( + len(tokens) >= 2 and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1] + ) + + consecutive_timestamps = [ + i + for i in range(len(tokens)) + if i > 0 + and tokens[i] >= tokenizer.timestamp_begin + and tokens[i - 1] >= tokenizer.timestamp_begin + ] + + if len(consecutive_timestamps) > 0: + slices = list(consecutive_timestamps) + if single_timestamp_ending: + slices.append(len(tokens)) + + last_slice = 0 + for current_slice in slices: + sliced_tokens = tokens[last_slice:current_slice] + start_timestamp_position = sliced_tokens[0] - tokenizer.timestamp_begin + end_timestamp_position = sliced_tokens[-1] - tokenizer.timestamp_begin + start_time = ( + time_offset + start_timestamp_position * self.time_precision + ) + end_time = time_offset + end_timestamp_position * self.time_precision + + current_segments.append( + dict( + seek=seek, + start=start_time, + end=end_time, + tokens=sliced_tokens, + ) + ) + last_slice = current_slice + + if single_timestamp_ending: + # single timestamp at the end means no speech after the last timestamp. + seek += segment_size + else: + # otherwise, ignore the unfinished segment and seek to the last timestamp + last_timestamp_position = ( + tokens[last_slice - 1] - tokenizer.timestamp_begin + ) + seek += last_timestamp_position * self.input_stride + + else: + duration = segment_duration + timestamps = [ + token for token in tokens if token >= tokenizer.timestamp_begin + ] + if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin: + last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin + duration = last_timestamp_position * self.time_precision + + current_segments.append( + dict( + seek=seek, + start=time_offset, + end=time_offset + duration, + tokens=tokens, + ) + ) + + seek += segment_size + + return current_segments, seek, single_timestamp_ending + def generate_segments( self, - features: np.ndarray, + features: torch.Tensor, tokenizer: Tokenizer, options: TranscriptionOptions, encoder_output: Optional[ctranslate2.StorageView] = None, @@ -434,7 +1067,7 @@ def generate_segments( content_duration = float(content_frames * self.feature_extractor.time_per_frame) if isinstance(options.clip_timestamps, str): - TranscriptionOptions.clip_timestamps = [ + options.clip_timestamps = [ float(ts) for ts in ( options.clip_timestamps.split(",") @@ -442,6 +1075,7 @@ def generate_segments( else [] ) ] + seek_points: List[int] = [ round(ts * self.frames_per_second) for ts in options.clip_timestamps ] @@ -496,9 +1130,9 @@ def generate_segments( content_frames - seek, seek_clip_end - seek, ) - segment = features[:, seek:seek + segment_size] + segment = features[:, seek : seek + segment_size] segment_duration = segment_size * self.feature_extractor.time_per_frame - segment = pad_or_trim(segment, self.feature_extractor.nb_max_frames) + segment = pad_or_trim(segment) if self.logger.isEnabledFor(logging.DEBUG): self.logger.debug( @@ -506,11 +1140,34 @@ def generate_segments( ) previous_tokens = all_tokens[prompt_reset_since:] + + if encoder_output is None: + encoder_output = self.encode(segment) + + # Perform language detection at every segment to update task based on output language, + # if the language is english, task is transcribe, + # else the task is translate to english (default) + # or transcribe if 'output_language' is 'hybrid'. 
+ if options.multilingual: + results = self.model.detect_language(encoder_output) + language_token, language_probability = results[0][0] + language = language_token[2:-2] + if options.output_language == "en" and language != "en": + task = "translate" + else: + task = "transcribe" + + # Update tokenizer based on task and language + tokenizer.task = tokenizer.tokenizer.token_to_id(f"<|{task}|>") + tokenizer.language = tokenizer.tokenizer.token_to_id(language_token) + tokenizer.language_code = language + # Update prompt based on task and language prompt = self.get_prompt( tokenizer, previous_tokens, without_timestamps=options.without_timestamps, prefix=options.prefix if seek == 0 else None, + hotwords=options.hotwords, ) if seek > 0 or encoder_output is None: @@ -541,6 +1198,18 @@ def generate_segments( options.no_speech_threshold, ) + # Skip if the logprob is very low (below the threshold value), + # despite no_speech_prob being low (ex: Too ambiguous outputs) + if options.log_prob_low_threshold: + if avg_logprob < options.log_prob_low_threshold: + should_skip = True + self.logger.debug( + "log prob low threshold is met (%f > %f)", + avg_logprob, + options.log_prob_low_threshold, + ) + + if should_skip: # fast-forward to the next segment boundary seek += segment_size continue @@ -548,7 +1217,6 @@ def generate_segments( tokens = result.sequences_ids[0] previous_seek = seek - current_segments = [] # anomalous words are very long/short/improbable def word_anomaly_score(word: dict) -> float: @@ -574,83 +1242,22 @@ def is_segment_anomaly(segment: Optional[dict]) -> bool: def next_words_segment(segments: List[dict]) -> Optional[dict]: return next((s for s in segments if s["words"]), None) - single_timestamp_ending = ( - len(tokens) >= 2 - and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1] + ( + current_segments, + seek, + single_timestamp_ending, + ) = self._split_segments_by_timestamps( + tokenizer=tokenizer, + tokens=tokens, + time_offset=time_offset, + segment_size=segment_size, + segment_duration=segment_duration, + seek=seek, ) - consecutive_timestamps = [ - i - for i in range(len(tokens)) - if i > 0 - and tokens[i] >= tokenizer.timestamp_begin - and tokens[i - 1] >= tokenizer.timestamp_begin - ] - - if len(consecutive_timestamps) > 0: - slices = list(consecutive_timestamps) - if single_timestamp_ending: - slices.append(len(tokens)) - - last_slice = 0 - for current_slice in slices: - sliced_tokens = tokens[last_slice:current_slice] - start_timestamp_position = ( - sliced_tokens[0] - tokenizer.timestamp_begin - ) - end_timestamp_position = ( - sliced_tokens[-1] - tokenizer.timestamp_begin - ) - start_time = ( - time_offset + start_timestamp_position * self.time_precision - ) - end_time = ( - time_offset + end_timestamp_position * self.time_precision - ) - - current_segments.append( - dict( - seek=seek, - start=start_time, - end=end_time, - tokens=sliced_tokens, - ) - ) - last_slice = current_slice - - if single_timestamp_ending: - # single timestamp at the end means no speech after the last timestamp. 
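
The per-segment handling above is driven by new transcribe() options: multilingual re-runs language detection on every window, output_language chooses between translated ("en") and code-switched ("hybrid") text, hotwords injects hint phrases into the prompt, and log_prob_low_threshold skips low-confidence output regardless of no_speech_prob. A rough call combining them; the file name and hotword strings are invented for illustration and `model` is assumed as before:

    segments, info = model.transcribe(
        "code_switched_talk.wav",
        multilingual=True,             # re-detect the language on every 30 s window
        output_language="hybrid",      # keep code-switched text; "en" would translate instead
        hotwords="CTranslate2 Whisper WebSocket",  # hint phrases, ignored when prefix is set
        log_prob_low_threshold=-2.0,   # drop segments whose avg_logprob falls below this
    )
    for seg in segments:
        print(seg.text)
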
- seek += segment_size - else: - # otherwise, ignore the unfinished segment and seek to the last timestamp - last_timestamp_position = ( - tokens[last_slice - 1] - tokenizer.timestamp_begin - ) - seek += last_timestamp_position * self.input_stride - - else: - duration = segment_duration - timestamps = [ - token for token in tokens if token >= tokenizer.timestamp_begin - ] - if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin: - last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin - duration = last_timestamp_position * self.time_precision - - current_segments.append( - dict( - seek=seek, - start=time_offset, - end=time_offset + duration, - tokens=tokens, - ) - ) - - seek += segment_size - if options.word_timestamps: self.add_word_timestamps( - current_segments, + [current_segments], tokenizer, encoder_output, segment_size, @@ -658,7 +1265,6 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: options.append_punctuations, last_speech_timestamp=last_speech_timestamp, ) - if not single_timestamp_ending: last_word_end = get_end(current_segments) if last_word_end is not None and last_word_end > time_offset: @@ -685,7 +1291,7 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: continue if is_segment_anomaly(segment): next_segment = next_words_segment( - current_segments[si + 1:] + current_segments[si + 1 :] ) if next_segment is not None: hal_next_start = next_segment["words"][0]["start"] @@ -715,7 +1321,6 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: last_word_end = get_end(current_segments) if last_word_end is not None: last_speech_timestamp = last_word_end - for segment in current_segments: tokens = segment["tokens"] text = tokenizer.decode(tokens) @@ -758,12 +1363,13 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: prompt_reset_since = len(all_tokens) return all_segments - def encode(self, features: np.ndarray) -> ctranslate2.StorageView: + def encode(self, features: torch.Tensor) -> ctranslate2.StorageView: # When the model is running on multiple GPUs, the encoder output should be moved # to the CPU since we don't know which GPU will handle the next job. 
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 - features = np.expand_dims(features, 0) + if features.ndim == 2: + features = features.unsqueeze(0) features = get_ctranslate2_storage(features) return self.model.encode(features, to_cpu=to_cpu) @@ -904,12 +1510,19 @@ def get_prompt( previous_tokens: List[int], without_timestamps: bool = False, prefix: Optional[str] = None, + hotwords: Optional[str] = None, ) -> List[int]: prompt = [] - if previous_tokens: + if previous_tokens or (hotwords and not prefix): prompt.append(tokenizer.sot_prev) - prompt.extend(previous_tokens[-(self.max_length // 2 - 1):]) + if hotwords and not prefix: + hotwords_tokens = tokenizer.encode(" " + hotwords.strip()) + if len(hotwords_tokens) >= self.max_length // 2: + hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1] + prompt.extend(hotwords_tokens) + if previous_tokens: + prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :]) prompt.extend(tokenizer.sot_sequence) @@ -935,115 +1548,127 @@ def add_word_timestamps( prepend_punctuations: str, append_punctuations: str, last_speech_timestamp: float, - ) -> None: + ) -> float: if len(segments) == 0: return - text_tokens_per_segment = [ - [token for token in segment["tokens"] if token < tokenizer.eot] - for segment in segments - ] + text_tokens = [] + text_tokens_per_segment = [] + for segment in segments: + segment_tokens = [ + [token for token in subsegment["tokens"] if token < tokenizer.eot] + for subsegment in segment + ] + text_tokens.append(list(itertools.chain.from_iterable(segment_tokens))) + text_tokens_per_segment.append(segment_tokens) - text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment)) - alignment = self.find_alignment( + alignments = self.find_alignment( tokenizer, text_tokens, encoder_output, num_frames ) - word_durations = np.array([word["end"] - word["start"] for word in alignment]) - word_durations = word_durations[word_durations.nonzero()] - median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0 - median_duration = min(0.7, float(median_duration)) - max_duration = median_duration * 2 - - # hack: truncate long words at sentence boundaries. - # a better segmentation algorithm based on VAD should be able to replace this. - if len(word_durations) > 0: - sentence_end_marks = ".。!!??" - # ensure words at sentence boundaries - # are not longer than twice the median word duration. 
- for i in range(1, len(alignment)): - if alignment[i]["end"] - alignment[i]["start"] > max_duration: - if alignment[i]["word"] in sentence_end_marks: - alignment[i]["end"] = alignment[i]["start"] + max_duration - elif alignment[i - 1]["word"] in sentence_end_marks: - alignment[i]["start"] = alignment[i]["end"] - max_duration - - merge_punctuations(alignment, prepend_punctuations, append_punctuations) - - time_offset = ( - segments[0]["seek"] - * self.feature_extractor.hop_length - / self.feature_extractor.sampling_rate - ) - - word_index = 0 - - for segment, text_tokens in zip(segments, text_tokens_per_segment): - saved_tokens = 0 - words = [] - - while word_index < len(alignment) and saved_tokens < len(text_tokens): - timing = alignment[word_index] + median_max_durations = [] + for alignment in alignments: + word_durations = np.array( + [word["end"] - word["start"] for word in alignment] + ) + word_durations = word_durations[word_durations.nonzero()] + median_duration = ( + np.median(word_durations) if len(word_durations) > 0 else 0.0 + ) + median_duration = min(0.7, float(median_duration)) + max_duration = median_duration * 2 - if timing["word"]: - words.append( - dict( - word=timing["word"], - start=round(time_offset + timing["start"], 2), - end=round(time_offset + timing["end"], 2), - probability=timing["probability"], + # hack: truncate long words at sentence boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. + if len(word_durations) > 0: + sentence_end_marks = ".。!!??" + # ensure words at sentence boundaries + # are not longer than twice the median word duration. + for i in range(1, len(alignment)): + if alignment[i]["end"] - alignment[i]["start"] > max_duration: + if alignment[i]["word"] in sentence_end_marks: + alignment[i]["end"] = alignment[i]["start"] + max_duration + elif alignment[i - 1]["word"] in sentence_end_marks: + alignment[i]["start"] = alignment[i]["end"] - max_duration + + merge_punctuations(alignment, prepend_punctuations, append_punctuations) + median_max_durations.append((median_duration, max_duration)) + + for segment_idx, segment in enumerate(segments): + word_index = 0 + time_offset = segment[0]["start"] + median_duration, max_duration = median_max_durations[segment_idx] + for subsegment_idx, subsegment in enumerate(segment): + saved_tokens = 0 + words = [] + + while word_index < len(alignments[segment_idx]) and saved_tokens < len( + text_tokens_per_segment[segment_idx][subsegment_idx] + ): + timing = alignments[segment_idx][word_index] + + if timing["word"]: + words.append( + dict( + word=timing["word"], + start=round(time_offset + timing["start"], 2), + end=round(time_offset + timing["end"], 2), + probability=timing["probability"], + ) ) - ) - saved_tokens += len(timing["tokens"]) - word_index += 1 + saved_tokens += len(timing["tokens"]) + word_index += 1 + + # hack: truncate long words at segment boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. + if len(words) > 0: + # ensure the first and second word after a pause is not longer than + # twice the median word duration. 
+ if words[0][ + "end" + ] - last_speech_timestamp > median_duration * 4 and ( + words[0]["end"] - words[0]["start"] > max_duration + or ( + len(words) > 1 + and words[1]["end"] - words[0]["start"] > max_duration * 2 + ) + ): + if ( + len(words) > 1 + and words[1]["end"] - words[1]["start"] > max_duration + ): + boundary = max( + words[1]["end"] / 2, words[1]["end"] - max_duration + ) + words[0]["end"] = words[1]["start"] = boundary + words[0]["start"] = max(0, words[0]["end"] - max_duration) - # hack: truncate long words at segment boundaries. - # a better segmentation algorithm based on VAD should be able to replace this. - if len(words) > 0: - # ensure the first and second word after a pause is not longer than - # twice the median word duration. - if words[0]["end"] - last_speech_timestamp > median_duration * 4 and ( - words[0]["end"] - words[0]["start"] > max_duration - or ( - len(words) > 1 - and words[1]["end"] - words[0]["start"] > max_duration * 2 - ) - ): + # prefer the segment-level start timestamp if the first word is too long. if ( - len(words) > 1 - and words[1]["end"] - words[1]["start"] > max_duration + subsegment["start"] < words[0]["end"] + and subsegment["start"] - 0.5 > words[0]["start"] ): - boundary = max( - words[1]["end"] / 2, words[1]["end"] - max_duration + words[0]["start"] = max( + 0, + min(words[0]["end"] - median_duration, subsegment["start"]), ) - words[0]["end"] = words[1]["start"] = boundary - words[0]["start"] = max(0, words[0]["end"] - max_duration) - - # prefer the segment-level start timestamp if the first word is too long. - if ( - segment["start"] < words[0]["end"] - and segment["start"] - 0.5 > words[0]["start"] - ): - words[0]["start"] = max( - 0, min(words[0]["end"] - median_duration, segment["start"]) - ) - else: - segment["start"] = words[0]["start"] + else: + subsegment["start"] = words[0]["start"] - # prefer the segment-level end timestamp if the last word is too long. - if ( - segment["end"] > words[-1]["start"] - and segment["end"] + 0.5 < words[-1]["end"] - ): - words[-1]["end"] = max( - words[-1]["start"] + median_duration, segment["end"] - ) - else: - segment["end"] = words[-1]["end"] - - last_speech_timestamp = segment["end"] + # prefer the segment-level end timestamp if the last word is too long. 
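
add_word_timestamps() now operates on lists of sub-segments and returns the updated last_speech_timestamp, but the shape consumers see is unchanged: word timings still arrive as Word objects on Segment.words. A short consumption sketch, reusing the placeholder file name and the `model` instance from the earlier sketches:

    segments, _ = model.transcribe("audio.wav", word_timestamps=True)
    for seg in segments:
        for w in seg.words or []:
            print(f"{w.start:6.2f}-{w.end:6.2f}  p={w.probability:.2f} {w.word}")
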
+ if ( + subsegment["end"] > words[-1]["start"] + and subsegment["end"] + 0.5 < words[-1]["end"] + ): + words[-1]["end"] = max( + words[-1]["start"] + median_duration, subsegment["end"] + ) + else: + subsegment["end"] = words[-1]["end"] - segment["words"] = words + last_speech_timestamp = subsegment["end"] + segments[segment_idx][subsegment_idx]["words"] = words + return last_speech_timestamp def find_alignment( self, @@ -1056,51 +1681,333 @@ def find_alignment( if len(text_tokens) == 0: return [] - result = self.model.align( + results = self.model.align( encoder_output, tokenizer.sot_sequence, - [text_tokens], + text_tokens, num_frames, median_filter_width=median_filter_width, - )[0] + ) + return_list = [] + for result, text_token in zip(results, text_tokens): + text_token_probs = result.text_token_probs + alignments = result.alignments + text_indices = np.array([pair[0] for pair in alignments]) + time_indices = np.array([pair[1] for pair in alignments]) + + words, word_tokens = tokenizer.split_to_word_tokens( + text_token + [tokenizer.eot] + ) + if len(word_tokens) <= 1: + # return on eot only + # >>> np.pad([], (1, 0)) + # array([0.]) + # This results in crashes when we lookup jump_times with float, like + # IndexError: arrays used as indices must be of integer (or boolean) type + return [] + word_boundaries = np.pad( + np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0) + ) + if len(word_boundaries) <= 1: + return [] - text_token_probs = result.text_token_probs + jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype( + bool + ) + jump_times = time_indices[jumps] / self.tokens_per_second + start_times = jump_times[word_boundaries[:-1]] + end_times = jump_times[word_boundaries[1:]] + word_probabilities = [ + np.mean(text_token_probs[i:j]) + for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) + ] - alignments = result.alignments - text_indices = np.array([pair[0] for pair in alignments]) - time_indices = np.array([pair[1] for pair in alignments]) + return_list.append( + [ + dict( + word=word, + tokens=tokens, + start=start, + end=end, + probability=probability, + ) + for word, tokens, start, end, probability in zip( + words, word_tokens, start_times, end_times, word_probabilities + ) + ] + ) + return return_list - words, word_tokens = tokenizer.split_to_word_tokens( - text_tokens + [tokenizer.eot] + def generate_segment_batched( + self, + features: torch.Tensor, + tokenizer: Tokenizer, + options: dict, + ): + batch_size = features.shape[0] + all_tokens = [] + prompt_reset_since = 0 + + if options["initial_prompt"] is not None: + initial_prompt = " " + options["initial_prompt"].strip() + initial_prompt_tokens = tokenizer.encode(initial_prompt) + all_tokens.extend(initial_prompt_tokens) + previous_tokens = all_tokens[prompt_reset_since:] + prompt = self.get_prompt( + tokenizer, + previous_tokens, + without_timestamps=options["without_timestamps"], + prefix=options["prefix"], ) - if len(word_tokens) <= 1: - # return on eot only - # >>> np.pad([], (1, 0)) - # array([0.]) - # This results in crashes when we lookup jump_times with float, like - # IndexError: arrays used as indices must be of integer (or boolean) type - return [] - word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) - if len(word_boundaries) <= 1: - return [] - jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) - jump_times = time_indices[jumps] / self.tokens_per_second - start_times = jump_times[word_boundaries[:-1]] - end_times = 
jump_times[word_boundaries[1:]] - word_probabilities = [ - np.mean(text_token_probs[i:j]) - for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) + encoder_output = self.encode(features) + + result = self.model.generate( + encoder_output, + [prompt] * batch_size, + beam_size=options["beam_size"], + patience=options["patience"], + length_penalty=options["length_penalty"], + max_length=self.max_length, + suppress_blank=options["suppress_blank"], + suppress_tokens=options["suppress_tokens"], + return_scores=True, + return_no_speech_prob=True, + ) + + output = [] + for res in result: + output.append({}) + # return scores + seq_len = len(res.sequences_ids[0]) + cum_logprob = res.scores[0] * (seq_len ** options["length_penalty"]) + output[-1]["avg_logprob"] = cum_logprob / (seq_len + 1) + + # return no speech prob + output[-1]["no_speech_prob"] = res.no_speech_prob + output[-1]["tokens"] = res.sequences_ids[0] + + return encoder_output, output + + def detect_language(self, audio: torch.Tensor): + to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 + segment = self.feature_extractor(audio, padding=True, to_cpu=to_cpu)[ + :, : self.feature_extractor.nb_max_frames ] + encoder_output = self.encode(pad_or_trim(segment)) + results = self.model.detect_language(encoder_output) + language_token, language_probability = results[0][0] + language = language_token[2:-2] + self.logger.info( + f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio..." + ) + all_language_probs = [(token[2:-2], prob) for (token, prob) in results[0]] + return language, language_probability, all_language_probs + + def detect_language_multi_segment( + self, audio: Union[str, BinaryIO, torch.Tensor], params: Optional[dict] = None + ): + """ + Detect language based on N highly-confident segments of a language. + """ + # The threshold is used to decide if the audio is silence or not. + # The default is 0.02 (2.0%) i.e, if more than 2.0% of the audio is silent, + # the audio is considered as silence. + if not params: + params = { + "multilingual": False, + "speech_percentage_threshold": 0.02, + "language_detection_segments": 4, + "vad_filter": True, + "vad_min_silence_duration": 2500, + "language_threshold": 0.7, + } + + if params.get("multilingual", False): + logging.warning( + "lang_id is not supported for multilingual audios, detecting the major language." 
+ ) + + speech_percentage_threshold = params.get("speech_percentage_threshold", 0.02) + language_threshold = params.get("language_threshold", 0.7) + num_detection_segments = params.get("language_detection_segments", 4) + vad_filter_enabled = params.get("vad_filter", True) + vad_params = dict( + min_silence_duration_ms=params.get("vad_min_silence_duration", 2500) + ) - return [ - dict( - word=word, tokens=tokens, start=start, end=end, probability=probability + if vad_filter_enabled: + vad_params = VadOptions(**vad_params) + + # decode audio if it is not decoded already + sampling_rate = self.feature_extractor.sampling_rate + if not isinstance(audio, torch.Tensor): + audio: torch.Tensor = decode_audio(audio, sampling_rate=sampling_rate) + + # calculate duration of audio as number of seconds + # audio.shape[0] is the number of samples in the audio + # sampling_rate is the number of samples per second + # if we divide the number of samples by the number of samples per second, + # we get the duration in seconds + duration = audio.shape[0] / sampling_rate + + # Check if vad is enabled, and collect voiced segments + if vad_filter_enabled: + # get chunks of audio that contain speech + speech_chunks = get_speech_timestamps(audio, vad_params) + # merge chunks of audio that contain speech into a single array + audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks) + audio = torch.cat(audio_chunks, dim=0) + + # calculate new duration of audio without silence + duration_vad = audio.shape[0] / sampling_rate + + logging.debug( + f"Lang ID: VAD filter removed {duration - duration_vad} sec of audio" ) - for word, tokens, start, end, probability in zip( - words, word_tokens, start_times, end_times, word_probabilities + + # if the audio after VAD is less than 2% of the original audio, consider it as silence + if duration_vad / duration < speech_percentage_threshold: + return {"language_code": None, "language_confidence": 1.0} + + # update duration to be the duration after VAD + duration = duration_vad + + # if the duration of the audio is less than 1 second, consider it as silence + if duration < 1.0: + return {"language_code": None, "language_confidence": 1.0} + + # number of feature frames in 30 seconds of audio is 3000 + nb_max_frames = self.feature_extractor.nb_max_frames + + # extract features from audio with padding (default) + to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 + features = self.feature_extractor(audio, to_cpu=to_cpu) + + # number of segments in the audio + num_segments = features.shape[-1] // nb_max_frames + # more number of segments than possible with the duration of file + if num_detection_segments > num_segments: + logging.warning( + f"Lang ID: Can not have more segments, setting {num_segments} segments." ) - ] + num_detection_segments = num_segments + + # create a list of indices to randomly select segments from + indices = list(range(num_detection_segments)) + + # fix seed to get deterministic results + random.seed(0) + random.shuffle(indices) + + detected_languages = [] + all_language_probabilities = defaultdict(list) + confident_language_probabilities = defaultdict(list) + num_confident_segments_per_language = defaultdict(int) + + # Iterate over the randomly selected indices of the segments. + # + # For each segment, extract features and detect language. + # + # If the language is confident, add it to the list of confident segments for that language. 
+                #
+                # If the number of confident segments for a language
+                # is greater than or equal to the number of detection segments,
+                # return the language and the average probability of the language.
+                #
+                # If we are unable to get a sufficient number of confident predictions,
+                # return the most frequently detected language with maximum probability.
+                #
+                # We need a sufficient number of confident predictions per language, not in total.
+
+        for i in indices:
+            segment_features = features[:, i * nb_max_frames : (i + 1) * nb_max_frames]
+            try:
+                encoder_output = self.encode(pad_or_trim(segment_features))
+                results = self.model.detect_language(encoder_output)[0]
+
+            except ValueError as e:  # or RuntimeError
+                logging.error(f"Inference error: {e}")
+
+            # results is the list of classes (languages) and their probabilities (descending),
+            # e.g. [('<|de|>', 0.482177734375),('<|en|>', 0.283447265625),...]
+
+            # take top language token and probability
+            # and parse language token to strip out markers
+            # e.g. '<|de|>' -> 'de'
+
+            language_token = results[0][0]
+            language = language_token[2:-2]
+
+            language_probability = results[0][1]
+
+            detected_languages.append(language)
+            all_language_probabilities[language].append(language_probability)
+
+            # only consider if the language prediction is confident
+            if language_probability > language_threshold:
+                num_confident_segments_per_language[language] += 1
+
+                # Add language and probability to the list of languages when it is confident
+                confident_language_probabilities[language].append(language_probability)
+
+                # return the language when a sufficient number of confident segments is achieved
+                if (
+                    num_confident_segments_per_language[language]
+                    >= num_detection_segments
+                ):
+                    # Consider the average probability of only the confident segments
+                    mean = sum(confident_language_probabilities[language]) / len(
+                        confident_language_probabilities[language]
+                    )
+                    return {
+                        "language_code": language,
+                        "language_confidence": mean,
+                    }
+
+        # if we are unable to get a sufficient number of confident predictions,
+        # return the most frequently detected language.
+        # if there is a tie, return the one with maximum average probability.
+ counter = Counter(detected_languages) + + # Define the key function to select frequent language with attached probabilities + def key_func(language): + # Calculate the frequency of the language + frequency = counter[language] + + # Calculate the average probability of the language + prob_avg = sum(all_language_probabilities[language]) / len( + all_language_probabilities[language] + ) + + return frequency, prob_avg + + if detected_languages: + # Use the key function to find the language with maximum frequency and probability + max_language = max(detected_languages, key=key_func) + max_probability = sum(all_language_probabilities[max_language]) / len( + all_language_probabilities[max_language] + ) + + # Do additional checks for silence for non-confident case + # calculate RMS amplitude and DC offset + dc_offset = audio.mean() + audio_minus_dc_offset = audio - dc_offset + is_silent = ( + torch.all(audio.abs() < 0.01) + or torch.sqrt(torch.mean(audio_minus_dc_offset**2)) < 0.01 + ) + + if is_silent: + return {"language_code": None, "language_confidence": 1.0} + + return { + "language_code": max_language, + "language_confidence": max_probability, + } + + # Language is not detected for any segment and none of prev conditions met + return {"language_code": None, "language_confidence": 1.0} def restore_speech_timestamps( @@ -1117,30 +2024,26 @@ def restore_speech_timestamps( # Ensure the word start and end times are resolved to the same chunk. middle = (word.start + word.end) / 2 chunk_index = ts_map.get_chunk_index(middle) - word = word._replace( - start=ts_map.get_original_time(word.start, chunk_index), - end=ts_map.get_original_time(word.end, chunk_index), - ) + word.start = ts_map.get_original_time(word.start, chunk_index) + word.end = ts_map.get_original_time(word.end, chunk_index) words.append(word) - segment = segment._replace( - start=words[0].start, - end=words[-1].end, - words=words, - ) + segment.start = words[0].start + segment.end = words[-1].end + segment.words = words else: - segment = segment._replace( - start=ts_map.get_original_time(segment.start), - end=ts_map.get_original_time(segment.end), - ) - + segment.start = ts_map.get_original_time(segment.start) + segment.end = ts_map.get_original_time(segment.end) return segments -def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView: - segment = np.ascontiguousarray(segment) - segment = ctranslate2.StorageView.from_array(segment) +def get_ctranslate2_storage(segment: torch.Tensor) -> ctranslate2.StorageView: + segment = segment.contiguous() + segment = ctranslate2.StorageView.from_array( + segment if segment.is_cuda else segment.numpy() + ) # torch cpu tensors don't implement __array_interface__ + # https://github.com/pytorch/pytorch/issues/51156 return segment @@ -1151,15 +2054,16 @@ def get_compression_ratio(text: str) -> float: def get_suppressed_tokens( tokenizer: Tokenizer, - suppress_tokens: Optional[List[int]], + suppress_tokens: Tuple[int], ) -> Optional[List[int]]: - if not suppress_tokens or -1 in suppress_tokens: - return suppress_tokens - - suppress_tokens = list(suppress_tokens) + if -1 in suppress_tokens: + suppress_tokens = [t for t in suppress_tokens if t >= 0] + suppress_tokens.extend(tokenizer.non_speech_tokens) + elif suppress_tokens is None or len(suppress_tokens) == 0: + suppress_tokens = [] # interpret empty string as an empty list + else: + assert isinstance(suppress_tokens, list), "suppress_tokens must be a list" - # Ensure the following special tokens are suppressed when the user does - 
# not use the default set (-1). suppress_tokens.extend( [ tokenizer.transcribe, @@ -1170,7 +2074,7 @@ def get_suppressed_tokens( ] ) - return sorted(set(suppress_tokens)) + return tuple(sorted(set(suppress_tokens))) def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None: @@ -1183,9 +2087,11 @@ def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> if previous["word"].startswith(" ") and previous["word"].strip() in prepended: # prepend it to the following word following["word"] = previous["word"] + following["word"] - following["tokens"] = previous["tokens"] + following["tokens"] + if "tokens" in alignment[0].keys(): + following["tokens"] = previous["tokens"] + following["tokens"] + previous["tokens"] = [] previous["word"] = "" - previous["tokens"] = [] + else: j = i i -= 1 @@ -1199,9 +2105,11 @@ def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> if not previous["word"].endswith(" ") and following["word"] in appended: # append it to the previous word previous["word"] = previous["word"] + following["word"] - previous["tokens"] = previous["tokens"] + following["tokens"] + if "tokens" in alignment[0].keys(): + previous["tokens"] = previous["tokens"] + following["tokens"] + following["tokens"] = [] following["word"] = "" - following["tokens"] = [] + else: i = j - j += 1 + j += 1 \ No newline at end of file From e275d3494351959e9310af50b9a56c5105d36221 Mon Sep 17 00:00:00 2001 From: makaveli10 Date: Mon, 18 Nov 2024 13:06:50 +0530 Subject: [PATCH 2/4] Remove pinned tiktoken version from server requirements Signed-off-by: makaveli10 --- requirements/server.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements/server.txt b/requirements/server.txt index 8e10592..716bc13 100644 --- a/requirements/server.txt +++ b/requirements/server.txt @@ -9,5 +9,4 @@ scipy jiwer evaluate numpy<2 -tiktoken==0.8.0 openai-whisper==20240930 \ No newline at end of file From a6523b6b710d938f5b059a3758b279fd0ab58526 Mon Sep 17 00:00:00 2001 From: makaveli10 Date: Tue, 19 Nov 2024 02:13:11 -0500 Subject: [PATCH 3/4] Minor fixes for better punctuations Signed-off-by: makaveli10 --- whisper_live/server.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/whisper_live/server.py b/whisper_live/server.py index e8f241c..9ed4a02 100644 --- a/whisper_live/server.py +++ b/whisper_live/server.py @@ -417,7 +417,7 @@ def __init__(self, client_uid, websocket): self.prev_out = '' self.t_start = None self.exit = False - self.same_output_threshold = 0 + self.same_output_count = 0 self.show_prev_out_thresh = 5 # if pause(no output from whisper) show previous output for 5 seconds self.add_pause_thresh = 3 # add a blank to segment list as a pause(no speech) for 3 seconds self.transcript = [] @@ -794,6 +794,7 @@ def __init__(self, websocket, task="transcribe", device=None, language=None, cli self.initial_prompt = initial_prompt self.vad_parameters = vad_parameters or {"onset": 0.5} self.no_speech_thresh = 0.45 + self.same_output_threshold = 10 device = "cuda" if torch.cuda.is_available() else "cpu" if device == "cuda": @@ -1051,7 +1052,7 @@ def update_segments(self, segments, duration): last_segment = None # process complete segments - if len(segments) > 1: + if len(segments) > 1 and segments[-1].no_speech_prob <= self.no_speech_thresh: for i, s in enumerate(segments[:-1]): text_ = s.text self.text.append(text_) @@ -1065,7 +1066,7 @@ def update_segments(self, segments, duration): 
                     self.transcript.append(self.format_segment(start, end, text_, completed=True))
                     offset = min(duration, s.end)
 
-            # only process the segments if it satisfies the no_speech_thresh
+            # only process the last segment if it satisfies the no_speech_thresh
             if segments[-1].no_speech_prob <= self.no_speech_thresh:
                 self.current_out += segments[-1].text
                 last_segment = self.format_segment(
@@ -1075,14 +1076,15 @@ def update_segments(self, segments, duration):
                     completed=False
                 )
 
-        # if same incomplete segment is seen multiple times then update the offset
-        # and append the segment to the list
         if self.current_out.strip() == self.prev_out.strip() and self.current_out != '':
-            self.same_output_threshold += 1
+            self.same_output_count += 1
+            time.sleep(0.1)     # wait for some voice activity just in case there is an unintended pause from the speaker, for better punctuation.
         else:
-            self.same_output_threshold = 0
-
-        if self.same_output_threshold > 5:
+            self.same_output_count = 0
+
+        # if the same incomplete segment is seen multiple times then update the offset
+        # and append the segment to the list
+        if self.same_output_count > self.same_output_threshold:
             if not len(self.text) or self.text[-1].strip().lower() != self.current_out.strip().lower():
                 self.text.append(self.current_out)
                 self.transcript.append(self.format_segment(
@@ -1093,7 +1095,7 @@ def update_segments(self, segments, duration):
                 ))
             self.current_out = ''
             offset = duration
-            self.same_output_threshold = 0
+            self.same_output_count = 0
             last_segment = None
         else:
             self.prev_out = self.current_out

From a1650eaa4ffcd6eb523df8d2cf96137470b9aa9e Mon Sep 17 00:00:00 2001
From: makaveli10
Date: Tue, 19 Nov 2024 13:31:55 +0530
Subject: [PATCH 4/4] Fix client tests to write srt file

Signed-off-by: makaveli10
---
 whisper_live/client.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/whisper_live/client.py b/whisper_live/client.py
index 15b6306..a748bb6 100644
--- a/whisper_live/client.py
+++ b/whisper_live/client.py
@@ -259,7 +259,9 @@ def write_srt_file(self, output_path="output.srt"):
         """
 
         if self.server_backend == "faster_whisper":
-            if (self.last_segment) and self.transcript[-1]["text"] != self.last_segment["text"]:
+            if not self.transcript and self.last_segment is not None:
+                self.transcript.append(self.last_segment)
+            elif self.last_segment and self.transcript[-1]["text"] != self.last_segment["text"]:
                 self.transcript.append(self.last_segment)
             utils.create_srt_file(self.transcript, output_path)
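
For reference, a minimal sketch of calling the multi-segment language detection that whisper_live/transcriber.py gains in this series. The model size, device, compute type, and audio file name are illustrative assumptions; the method name, the params keys, and the shape of the returned dict follow the detect_language_multi_segment code shown above.

    # Sketch only: WhisperModel here is the vendored class in whisper_live/transcriber.py;
    # the construction arguments and "audio.wav" are placeholders, not part of this series.
    from whisper_live.transcriber import WhisperModel

    model = WhisperModel("small", device="cpu", compute_type="int8")

    # Keys mirror the params dict consumed by detect_language_multi_segment.
    result = model.detect_language_multi_segment(
        "audio.wav",  # a file path; a BinaryIO or torch.Tensor is also accepted
        params={
            "language_detection_segments": 4,
            "language_threshold": 0.7,
            "vad_filter": True,
            "vad_min_silence_duration": 2500,
        },
    )

    # Prints e.g. {"language_code": "en", "language_confidence": 0.93},
    # or {"language_code": None, "language_confidence": 1.0} when the audio is mostly silence.
    print(result)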