Skip to content

Commit

Permalink
Merge pull request #98 from makaveli10/change_model_size_param_name
Browse files Browse the repository at this point in the history
Change model size param name
  • Loading branch information
makaveli10 authored Jan 15, 2024
2 parents 076aebf + c810369 commit 0c01d7b
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Audio-Transcription-Chrome/popup.js
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ document.addEventListener("DOMContentLoaded", function () {

// Send a message to the background script to start capturing
let host = "localhost";
let port = "5901";
let port = "9090";
const useCollaboraServer = useServerCheckbox.checked;
if (useCollaboraServer){
host = "transcription.kurg.org"
Expand Down
2 changes: 1 addition & 1 deletion Audio-Transcription-Firefox/popup.js
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ document.addEventListener("DOMContentLoaded", function() {

startButton.addEventListener("click", function() {
let host = "localhost";
let port = "5901";
let port = "9090";
const useCollaboraServer = useServerCheckbox.checked;

if (useCollaboraServer){
Expand Down
10 changes: 9 additions & 1 deletion run_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import argparse
from whisper_live.server import TranscriptionServer

if __name__ == "__main__":
server = TranscriptionServer()
server.run("0.0.0.0")
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, default=None, help="Custom Faster Whisper Model")
args = parser.parse_args()
server.run(
"0.0.0.0",
9090,
custom_model_path=args.model_path
)
24 changes: 21 additions & 3 deletions whisper_live/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,14 @@ class Client:
INSTANCES = {}

def __init__(
self, host=None, port=None, is_multilingual=False, lang=None, translate=False, model_size="small"
self,
host=None,
port=None,
is_multilingual=False,
lang=None,
translate=False,
model_size="small",
use_custom_model=False
):
"""
Initializes a Client instance for audio recording and streaming to a server.
Expand Down Expand Up @@ -83,6 +90,8 @@ def __init__(
self.language = lang
self.model_size = model_size
self.server_error = False
self.use_custom_model = use_custom_model

if translate:
self.task = "translate"

Expand Down Expand Up @@ -221,6 +230,7 @@ def on_open(self, ws):
"language": self.language,
"task": self.task,
"model_size": self.model_size,
"use_custom_model": self.use_custom_model # if running your own server with a custom model
}
)
)
Expand Down Expand Up @@ -505,8 +515,16 @@ class TranscriptionClient:
transcription_client()
```
"""
def __init__(self, host, port, is_multilingual=False, lang=None, translate=False, model_size="small"):
self.client = Client(host, port, is_multilingual, lang, translate, model_size)
def __init__(self,
host,
port,
is_multilingual=False,
lang=None,
translate=False,
model_size="small",
use_custom_model=False
):
self.client = Client(host, port, is_multilingual, lang, translate, model_size, use_custom_model)

def __call__(self, audio=None, hls_url=None):
"""
Expand Down
41 changes: 31 additions & 10 deletions whisper_live/server.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import websockets
import time
import threading
Expand All @@ -12,6 +13,8 @@
import torch
import numpy as np
import time
import functools

from whisper_live.transcriber import WhisperModel


Expand Down Expand Up @@ -58,7 +61,7 @@ def get_wait_time(self):

return wait_time / 60

def recv_audio(self, websocket):
def recv_audio(self, websocket, custom_model_path=None):
"""
Receive audio chunks from a client in an infinite loop.
Expand Down Expand Up @@ -95,16 +98,22 @@ def recv_audio(self, websocket):
websocket.close()
del websocket
return

# validate custom model
if options["use_custom_model"]:
if custom_model_path is None or not os.path.exists(custom_model_path):
options["use_custom_model"] = False

client = ServeClient(
websocket,
multilingual=options["multilingual"],
language=options["language"],
task=options["task"],
client_uid=options["uid"],
model_size=options["model_size"],
model_size_or_path=custom_model_path if options["use_custom_model"] else options["model_size"],
initial_prompt=options.get("initial_prompt"),
vad_parameters=options.get("vad_parameters")
vad_parameters=options.get("vad_parameters"),
use_custom_model=options["use_custom_model"]
)

self.clients[websocket] = client
Expand Down Expand Up @@ -137,15 +146,22 @@ def recv_audio(self, websocket):
del websocket
break

def run(self, host, port=9090):
def run(self, host, port=9090, custom_model_path=None):
"""
Run the transcription server.
Args:
host (str): The host address to bind the server.
port (int): The port number to bind the server.
"""
with serve(self.recv_audio, host, port) as server:
with serve(
functools.partial(
self.recv_audio,
custom_model_path=custom_model_path
),
host,
port
) as server:
server.serve_forever()


Expand Down Expand Up @@ -190,9 +206,10 @@ def __init__(
multilingual=False,
language=None,
client_uid=None,
model_size="small",
model_size_or_path="small",
initial_prompt=None,
vad_parameters=None
vad_parameters=None,
use_custom_model=False
):
"""
Initialize a ServeClient instance.
Expand All @@ -216,7 +233,11 @@ def __init__(
"tiny", "base", "small", "medium", "large-v2", "large-v3"
]
self.multilingual = multilingual
self.model_size = self.get_model_size(model_size)
if not use_custom_model:
self.model_size_or_path = self.get_model_size(model_size_or_path)
else:
self.model_size_or_path = model_size_or_path

self.language = language if self.multilingual else "en"
self.task = task
self.websocket = websocket
Expand All @@ -225,11 +246,11 @@ def __init__(

device = "cuda" if torch.cuda.is_available() else "cpu"

if self.model_size == None:
if self.model_size_or_path == None:
return

self.transcriber = WhisperModel(
self.model_size,
self.model_size_or_path,
device=device,
compute_type="int8" if device=="cpu" else "float16",
local_files_only=False,
Expand Down

0 comments on commit 0c01d7b

Please sign in to comment.