diff --git a/README.md b/README.md index 19bb230..9540026 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,9 @@ client = TranscriptionClient( use_vad=False, save_output_recording=True, # Only used for microphone input, False by Default output_recording_filename="./output_recording.wav", # Only used for microphone input + options={ + 'initial_prompt': None, #To add context replace None with any context for the model like this: 'Jane Doe context' + }, max_clients=4, max_connection_time=600 ) diff --git a/whisper_live/client.py b/whisper_live/client.py index 15b6306..5b340c6 100644 --- a/whisper_live/client.py +++ b/whisper_live/client.py @@ -31,6 +31,7 @@ def __init__( srt_file_path="output.srt", use_vad=True, log_transcription=True, + options=None, max_clients=4, max_connection_time=600, ): @@ -61,9 +62,11 @@ def __init__( self.last_segment = None self.last_received_segment = None self.log_transcription = log_transcription + self.options = options self.max_clients = max_clients self.max_connection_time = max_connection_time + if translate: self.task = "translate" @@ -204,6 +207,7 @@ def on_open(self, ws): "task": self.task, "model": self.model, "use_vad": self.use_vad, + "options": self.options, "max_clients": self.max_clients, "max_connection_time": self.max_connection_time, } @@ -687,12 +691,13 @@ def __init__( output_recording_filename="./output_recording.wav", output_transcription_path="./output.srt", log_transcription=True, + options=None, max_clients=4, max_connection_time=600, ): self.client = Client( host, port, lang, translate, model, srt_file_path=output_transcription_path, - use_vad=use_vad, log_transcription=log_transcription, max_clients=max_clients, + use_vad=use_vad, log_transcription=log_transcription, options=options, max_clients=max_clients, max_connection_time=max_connection_time ) diff --git a/whisper_live/server.py b/whisper_live/server.py index b68df5a..75c3d22 100644 --- a/whisper_live/server.py +++ b/whisper_live/server.py @@ -191,8 +191,8 @@ def initialize_client( task=options["task"], client_uid=options["uid"], model=options["model"], - initial_prompt=options.get("initial_prompt"), - vad_parameters=options.get("vad_parameters"), + initial_prompt=options["options"].get("initial_prompt"), + vad_parameters=options["options"].get("vad_parameters"), use_vad=self.use_vad, single_model=self.single_model, )