Skip to content

Commit

Permalink
Merge pull request #82 from hcljsq/main
Browse files Browse the repository at this point in the history
Add the `large-v3` model
  • Loading branch information
makaveli10 authored Jan 1, 2024
2 parents db2e0bb + 02793a9 commit 01665a5
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 7 deletions.
1 change: 1 addition & 0 deletions Audio-Transcription-Chrome/popup.html
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@
<option value="small" selected>Small</option>
<option value="medium">Medium</option>
<option value="large-v2">Large-v2</option>
<option value="large-v3">Large-v3</option>
</select>
</div>
</body>
Expand Down
1 change: 1 addition & 0 deletions Audio-Transcription-Firefox/popup.html
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@
<option value="small" selected>Small</option>
<option value="medium">Medium</option>
<option value="large-v2">Large-v2</option>
<option value="large-v3">Large-v3</option>
</select>
</div>
</body>
Expand Down
2 changes: 1 addition & 1 deletion requirements/server.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
PyAudio
faster-whisper==0.9.0
faster-whisper==0.10.0
--extra-index-url https://download.pytorch.org/whl/cu111
torch==1.10.1
torchaudio==0.10.1
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
),
install_requires=[
"PyAudio",
"faster-whisper==0.9.0",
"faster-whisper==0.10.0",
"torch",
"torchaudio",
"websockets",
Expand Down
4 changes: 2 additions & 2 deletions whisper_live/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def __init__(
self.data = b""
self.frames = b""
self.model_sizes = [
"tiny", "base", "small", "medium", "large-v2"
"tiny", "base", "small", "medium", "large-v2", "large-v3"
]
self.multilingual = multilingual
self.model_size = self.get_model_size(model_size)
Expand Down Expand Up @@ -277,7 +277,7 @@ def get_model_size(self, model_size):
)
return None

if model_size == "large-v2":
if model_size in ["large-v2", "large-v3"]:
self.multilingual = True
return model_size

Expand Down
25 changes: 22 additions & 3 deletions whisper_live/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import logging
import os
import zlib
import json
from inspect import signature

from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union

Expand Down Expand Up @@ -94,7 +96,7 @@ def __init__(
Args:
model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
small, small.en, medium, medium.en, large-v1, large-v2, or large), a path to a converted
small, small.en, medium, medium.en, large-v1, large-v2, large-v3, or large), a path to a converted
model directory, or a CTranslate2-converted Whisper model ID from the Hugging Face Hub.
When a size or a model ID is configured, the converted model is downloaded
from the Hugging Face Hub.
Expand Down Expand Up @@ -144,7 +146,8 @@ def __init__(
"openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
)

self.feature_extractor = FeatureExtractor()
self.feat_kwargs = self._get_feature_kwargs(model_path)
self.feature_extractor = FeatureExtractor(**self.feat_kwargs)
self.num_samples_per_token = self.feature_extractor.hop_length * 2
self.frames_per_second = (
self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
Expand All @@ -161,6 +164,22 @@ def supported_languages(self) -> List[str]:
"""The languages supported by the model."""
return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]

def _get_feature_kwargs(self, model_path) -> dict:
preprocessor_config_file = os.path.join(model_path, "preprocessor_config.json")
config = {}
if os.path.isfile(preprocessor_config_file):
try:
with open(preprocessor_config_file, "r", encoding="utf-8") as json_file:
config = json.load(json_file)
valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
config = {k: v for k, v in config.items() if k in valid_keys}
except json.JSONDecodeError as e:
self.logger.warning(
"Could not load preprocessor_config.json: %s", str(e)
)

return config

def transcribe(
self,
audio: Union[str, BinaryIO, np.ndarray],
Expand Down Expand Up @@ -914,7 +933,7 @@ def find_alignment(
words, word_tokens, start_times, end_times, word_probabilities
)
]

def destroy(self):
del self.model

Expand Down

0 comments on commit 01665a5

Please sign in to comment.